Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/contributing/model/adding_omni_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def forward(self, ...):
These keys are then accessible in your stage transition function:
```python
# In stage_input_processors/qwen3_omni.py
thinker_embeddings = output.multimodal_output["0"] # Access by key
thinker_prefill_embeddings = output.multimodal_output["0"] # Access by key
thinker_hidden_states = output.multimodal_output["24"]
```

Expand Down Expand Up @@ -513,11 +513,11 @@ def thinker2talker(
for thinker_output in thinker_outputs:
output = thinker_output.outputs[0]
# Extract thinker embeddings and hidden states
thinker_embeddings = output.multimodal_output["0"].float().clone().detach().cuda()
thinker_prefill_embeddings = output.multimodal_output["0"].float().clone().detach().cuda()
thinker_hidden_states = output.multimodal_output["24"].float().clone().detach().cuda()

info = {
"thinker_embeddings": thinker_embeddings,
"thinker_prefill_embeddings": thinker_prefill_embeddings,
"thinker_hidden_states": thinker_hidden_states,
"thinker_sequences": thinker_output.prompt_token_ids + output.token_ids,
"thinker_input_ids": thinker_output.prompt_token_ids,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,12 @@ def _update_request_payload(self, req_id: str, payload_data: dict[str, Any]) ->
self.request_payload[req_id] = payload_data
return payload_data
origin_payload = self.request_payload[req_id]
override_keys = payload_data.pop("override_keys", [])
for key, value in payload_data.items():
if key == "finished":
continue
elif key in override_keys:
payload_data[key] = value
elif isinstance(value, torch.Tensor) and key in origin_payload:
payload_data[key] = torch.cat([origin_payload[key], value], dim=0)
elif isinstance(value, list) and key in origin_payload:
Expand Down
50 changes: 43 additions & 7 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor,
else:
# decode
if not info_dict.get("decode_flag", False):
info_dict["num_processed_tokens"] = len(info_dict.get("thinker_input_ids", [])) + 1
info_dict["num_processed_tokens"] = 0
update_dict["decode_flag"] = True

last_talker_hidden, text_step, update_dict = self.talker_preprocess_decode(
Expand Down Expand Up @@ -673,7 +673,7 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch
start_index = info_dict.get("num_processed_tokens", 0)
end_index = start_index + input_embeds.shape[0]
# Read thinker outputs for prefill
thinker_sequence_embeds = info_dict.get("thinker_embeddings").to(
thinker_sequence_embeds = info_dict.get("thinker_prefill_embeddings").to(
device=self._module_device(self.talker), dtype=torch.bfloat16
) # Tensor [P,H]
thinker_hidden_states = info_dict.get("thinker_hidden_states").to(
Expand Down Expand Up @@ -703,7 +703,7 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch
if thinker_sequence_embeds is None or thinker_hidden_states is None:
raise ValueError(
"additional_information_by_req_id must include "
"'thinker_embeddings' and 'thinker_hidden_states' for talker prefill."
"'thinker_prefill_embeddings' and 'thinker_hidden_states' for talker prefill."
)

# Normalize to tensors
Expand Down Expand Up @@ -770,9 +770,35 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch
update_dict["tts_pad_embed_projected"] = pad_proj.detach().to("cpu").contiguous()
except Exception:
pass
self._talker_cache_thinker_decode_embeds(info_dict, update_dict)

return req_input_ids[start_index:end_index], req_embeds[start_index:end_index], update_dict

def _talker_cache_thinker_decode_embeds(
self,
info_dict: dict[str, Any],
update_dict: dict[str, Any],
) -> None:
"""
Fold newly received thinker decode embeddings into the per-request cache.

Reads ``thinker_decode_embeddings`` and ``cached_thinker_decode_embeddings``
from ``info_dict`` and records the outcome in ``update_dict``; ``info_dict``
itself is not modified and nothing is returned.

* If no cache exists yet, the incoming embeddings become the cache as-is
  (no device or dtype conversion happens in that branch).
* Otherwise both tensors are moved to the device returned by
  ``self._module_device(self.talker)``, cast to ``torch.bfloat16``, and
  concatenated along dim 0 with the existing cache first.

``update_dict["thinker_decode_embeddings"]`` is set to ``None``.

NOTE(review): indentation is not visible in this view, so whether that
final reset sits inside the ``is not None`` guard or runs unconditionally
could not be confirmed -- verify against the real file.

Args:
    info_dict: Per-request state to read the pending and cached
        embeddings from.
    update_dict: Output mapping that receives the refreshed cache and the
        cleared pending entry.
"""
# Pending decode embeddings that have not been merged into the cache yet.
thinker_decode_embeds = info_dict.get("thinker_decode_embeddings", None)
if thinker_decode_embeds is not None:
cached_thinker_decode_embeds = info_dict.get("cached_thinker_decode_embeddings", None)
if cached_thinker_decode_embeds is None:
# First batch for this request: it seeds the cache unchanged.
update_dict["cached_thinker_decode_embeddings"] = thinker_decode_embeds
else:
# Align device and dtype on both operands before concatenating them.
cached_thinker_decode_embeds = cached_thinker_decode_embeds.to(
device=self._module_device(self.talker), dtype=torch.bfloat16
)
thinker_decode_embeds = thinker_decode_embeds.to(
device=self._module_device(self.talker), dtype=torch.bfloat16
)
# Append along dim 0, keeping the previously cached rows first.
update_dict["cached_thinker_decode_embeddings"] = torch.cat(
[cached_thinker_decode_embeds, thinker_decode_embeds], dim=0
)
# Clear the pending entry; presumably so the same embeddings are not merged twice -- TODO confirm.
update_dict["thinker_decode_embeddings"] = None

def _thinker_to_talker_prefill(
self,
thinker_embed: torch.Tensor,
Expand Down Expand Up @@ -866,15 +892,25 @@ def _thinker_decode_to_talker_decode(
Returns:
(input_ids, input_embeds) for talker
"""
thinker_embed = info_dict.get("thinker_embeddings", None)
cached_thinker_decode_embeds = info_dict.get("cached_thinker_decode_embeddings", None)
thinker_decode_embed = info_dict.get("thinker_decode_embeddings", None)
start_index = info_dict.get("num_processed_tokens", 0)
if start_index >= thinker_embed.shape[0]:
thinker_output_token_ids = info_dict.get("thinker_output_token_ids", [])
if start_index >= len(thinker_output_token_ids) - 1:
if info_dict.get("finished_flag"):
return self.tts_pad_embed.to(device)
update_dict["finished_flag"] = True
return self.tts_eos_embed.to(device)

thinker_embed = thinker_embed[start_index : start_index + 1].to(device)
if cached_thinker_decode_embeds is not None and start_index < cached_thinker_decode_embeds.shape[0]:
cached_thinker_decode_embeds = cached_thinker_decode_embeds.to(device)
thinker_embed = cached_thinker_decode_embeds[start_index]
if thinker_decode_embed is not None:
thinker_decode_embed = thinker_decode_embed.to(device)
cached_thinker_decode_embeds = torch.cat([cached_thinker_decode_embeds, thinker_decode_embed], dim=0)
update_dict["cached_thinker_decode_embeddings"] = cached_thinker_decode_embeds
else:
thinker_embed = thinker_decode_embed.to(device)
update_dict["thinker_decode_embeddings"] = None
return self.talker.text_projection(thinker_embed).to(device)

def talker_preprocess_decode(
Expand Down
21 changes: 14 additions & 7 deletions vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def thinker2talker_async_chunk(
all_token_ids = _ensure_list(all_token_ids)
prompt_token_ids = _ensure_list(prompt_token_ids)
talker_additional_info = {
"thinker_embeddings": pooling_output.get("0").detach().cpu(),
"thinker_prefill_embeddings": pooling_output.get("0").detach().cpu(),
"thinker_hidden_states": pooling_output.get("24").detach().cpu(),
"thinker_sequences": all_token_ids,
"thinker_input_ids": prompt_token_ids,
Expand All @@ -121,8 +121,12 @@ def thinker2talker_async_chunk(
return None
else:
save_payload = transfer_manager.request_payload.pop(request_id)
talker_additional_info["thinker_embeddings"] = torch.cat(
(save_payload.get("thinker_embeddings"), talker_additional_info.get("thinker_embeddings")), dim=0
talker_additional_info["thinker_prefill_embeddings"] = torch.cat(
(
save_payload.get("thinker_prefill_embeddings"),
talker_additional_info.get("thinker_prefill_embeddings"),
),
dim=0,
)
talker_additional_info["thinker_hidden_states"] = torch.cat(
(save_payload.get("thinker_hidden_states"), talker_additional_info.get("thinker_hidden_states")),
Expand All @@ -134,12 +138,15 @@ def thinker2talker_async_chunk(
output_token_ids = _ensure_list(output_token_ids)

talker_additional_info = {
"thinker_embeddings": pooling_output.get("0").detach().cpu(),
"finished": torch.tensor(is_finished, dtype=torch.bool),
}

if not output_token_ids:
if output_token_ids:
talker_additional_info["override_keys"] = ["thinker_decode_embeddings", "thinker_output_token_ids"]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The override_keys are the same for each step, so we don't need to accumulate them or transmit them every time.

talker_additional_info["thinker_decode_embeddings"] = pooling_output.get("0").detach().cpu()
talker_additional_info["thinker_output_token_ids"] = output_token_ids
else:
# When prefilling a chunked thinker, thinker_hidden_states needs to be updated.
talker_additional_info["thinker_prefill_embeddings"] = pooling_output.get("0").detach().cpu()
talker_additional_info["thinker_hidden_states"] = pooling_output.get("24").detach().cpu()
Comment on lines +147 to 150
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1: Preserve thinker embeddings on no-token async chunks

When `output_token_ids` is empty (the code comment says this happens while chunked thinker prefill is still running), this branch updates only `thinker_hidden_states` and never refreshes `thinker_embeddings`. The new decode path later computes `start_index` from `num_processed_tokens - thinker_embeddings.shape[0]` in `_thinker_decode_to_talker_decode`, so if prefill spans additional chunks after chunk 0, `thinker_embeddings.shape[0]` becomes stale. The talker can then jump past the available decode embeddings and emit EOS/pad early.

Useful? React with 👍 / 👎.

return talker_additional_info

Expand Down Expand Up @@ -177,7 +184,7 @@ def thinker2talker(
output = thinker_output.outputs[0]

info = {
"thinker_embeddings": output.multimodal_output["0"].detach().to(device=device, dtype=torch.float),
"thinker_prefill_embeddings": output.multimodal_output["0"].detach().to(device=device, dtype=torch.float),
"thinker_hidden_states": output.multimodal_output["24"].detach().to(device=device, dtype=torch.float),
"thinker_sequences": (
thinker_output.prompt_token_ids + output.token_ids
Expand Down