diff --git a/docs/contributing/model/adding_omni_model.md b/docs/contributing/model/adding_omni_model.md index 81499118623..6216b4b7859 100644 --- a/docs/contributing/model/adding_omni_model.md +++ b/docs/contributing/model/adding_omni_model.md @@ -460,7 +460,7 @@ def forward(self, ...): These keys are then accessible in your stage transition function: ```python # In stage_input_processors/qwen3_omni.py -thinker_embeddings = output.multimodal_output["0"] # Access by key +thinker_prefill_embeddings = output.multimodal_output["0"] # Access by key thinker_hidden_states = output.multimodal_output["24"] ``` @@ -513,11 +513,11 @@ def thinker2talker( for thinker_output in thinker_outputs: output = thinker_output.outputs[0] # Extract thinker embeddings and hidden states - thinker_embeddings = output.multimodal_output["0"].float().clone().detach().cuda() + thinker_prefill_embeddings = output.multimodal_output["0"].float().clone().detach().cuda() thinker_hidden_states = output.multimodal_output["24"].float().clone().detach().cuda() info = { - "thinker_embeddings": thinker_embeddings, + "thinker_prefill_embeddings": thinker_prefill_embeddings, "thinker_hidden_states": thinker_hidden_states, "thinker_sequences": thinker_output.prompt_token_ids + output.token_ids, "thinker_input_ids": thinker_output.prompt_token_ids, diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index d80323bbdec..c7289d8b846 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -185,9 +185,12 @@ def _update_request_payload(self, req_id: str, payload_data: dict[str, Any]) -> self.request_payload[req_id] = payload_data return payload_data origin_payload = self.request_payload[req_id] + override_keys = payload_data.pop("override_keys", []) for key, value in payload_data.items(): if key == "finished": continue + elif key in override_keys: + payload_data[key] = value elif isinstance(value, torch.Tensor) and key in origin_payload: payload_data[key] = torch.cat([origin_payload[key], value], dim=0) elif isinstance(value, list) and key in origin_payload: diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 7506c117d0a..ec928195ab7 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -603,7 +603,7 @@ def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, else: # decode if not info_dict.get("decode_flag", False): - info_dict["num_processed_tokens"] = len(info_dict.get("thinker_input_ids", [])) + 1 + info_dict["num_processed_tokens"] = 0 update_dict["decode_flag"] = True last_talker_hidden, text_step, update_dict = self.talker_preprocess_decode( @@ -673,7 +673,7 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch start_index = info_dict.get("num_processed_tokens", 0) end_index = start_index + input_embeds.shape[0] # Read thinker outputs for prefill - thinker_sequence_embeds = info_dict.get("thinker_embeddings").to( + thinker_sequence_embeds = info_dict.get("thinker_prefill_embeddings").to( device=self._module_device(self.talker), dtype=torch.bfloat16 ) # Tensor [P,H] thinker_hidden_states = info_dict.get("thinker_hidden_states").to( @@ -703,7 +703,7 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch if thinker_sequence_embeds is None or thinker_hidden_states is None: raise ValueError( "additional_information_by_req_id must include " - "'thinker_embeddings' and 'thinker_hidden_states' for talker prefill." + "'thinker_prefill_embeddings' and 'thinker_hidden_states' for talker prefill." ) # Normalize to tensors @@ -770,9 +770,35 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch update_dict["tts_pad_embed_projected"] = pad_proj.detach().to("cpu").contiguous() except Exception: pass + self._talker_cache_thinker_decode_embeds(info_dict, update_dict) return req_input_ids[start_index:end_index], req_embeds[start_index:end_index], update_dict + def _talker_cache_thinker_decode_embeds( + self, + info_dict: dict[str, Any], + update_dict: dict[str, Any], + ) -> None: + """ + Cache thinker embeds for decode stage. + """ + thinker_decode_embeds = info_dict.get("thinker_decode_embeddings", None) + if thinker_decode_embeds is not None: + cached_thinker_decode_embeds = info_dict.get("cached_thinker_decode_embeddings", None) + if cached_thinker_decode_embeds is None: + update_dict["cached_thinker_decode_embeddings"] = thinker_decode_embeds + else: + cached_thinker_decode_embeds = cached_thinker_decode_embeds.to( + device=self._module_device(self.talker), dtype=torch.bfloat16 + ) + thinker_decode_embeds = thinker_decode_embeds.to( + device=self._module_device(self.talker), dtype=torch.bfloat16 + ) + update_dict["cached_thinker_decode_embeddings"] = torch.cat( + [cached_thinker_decode_embeds, thinker_decode_embeds], dim=0 + ) + update_dict["thinker_decode_embeddings"] = None + def _thinker_to_talker_prefill( self, thinker_embed: torch.Tensor, @@ -866,15 +892,25 @@ def _thinker_decode_to_talker_decode( Returns: (input_ids, input_embeds) for talker """ - thinker_embed = info_dict.get("thinker_embeddings", None) + cached_thinker_decode_embeds = info_dict.get("cached_thinker_decode_embeddings", None) + thinker_decode_embed = info_dict.get("thinker_decode_embeddings", None) start_index = info_dict.get("num_processed_tokens", 0) - if start_index >= thinker_embed.shape[0]: + thinker_output_token_ids = info_dict.get("thinker_output_token_ids", []) + if start_index >= len(thinker_output_token_ids) - 1: if info_dict.get("finished_flag"): return self.tts_pad_embed.to(device) update_dict["finished_flag"] = True return self.tts_eos_embed.to(device) - - thinker_embed = thinker_embed[start_index : start_index + 1].to(device) + if cached_thinker_decode_embeds is not None and start_index < cached_thinker_decode_embeds.shape[0]: + cached_thinker_decode_embeds = cached_thinker_decode_embeds.to(device) + thinker_embed = cached_thinker_decode_embeds[start_index] + if thinker_decode_embed is not None: + thinker_decode_embed = thinker_decode_embed.to(device) + cached_thinker_decode_embeds = torch.cat([cached_thinker_decode_embeds, thinker_decode_embed], dim=0) + update_dict["cached_thinker_decode_embeddings"] = cached_thinker_decode_embeds + else: + thinker_embed = thinker_decode_embed.to(device) + update_dict["thinker_decode_embeddings"] = None return self.talker.text_projection(thinker_embed).to(device) def talker_preprocess_decode( diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 31a5ddd47f5..7cfc59f79c2 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -105,7 +105,7 @@ def thinker2talker_async_chunk( all_token_ids = _ensure_list(all_token_ids) prompt_token_ids = _ensure_list(prompt_token_ids) talker_additional_info = { - "thinker_embeddings": pooling_output.get("0").detach().cpu(), + "thinker_prefill_embeddings": pooling_output.get("0").detach().cpu(), "thinker_hidden_states": pooling_output.get("24").detach().cpu(), "thinker_sequences": all_token_ids, "thinker_input_ids": prompt_token_ids, @@ -121,8 +121,12 @@ def thinker2talker_async_chunk( return None else: save_payload = transfer_manager.request_payload.pop(request_id) - talker_additional_info["thinker_embeddings"] = torch.cat( - (save_payload.get("thinker_embeddings"), talker_additional_info.get("thinker_embeddings")), dim=0 + talker_additional_info["thinker_prefill_embeddings"] = torch.cat( + ( + save_payload.get("thinker_prefill_embeddings"), + talker_additional_info.get("thinker_prefill_embeddings"), + ), + dim=0, ) talker_additional_info["thinker_hidden_states"] = torch.cat( (save_payload.get("thinker_hidden_states"), talker_additional_info.get("thinker_hidden_states")), @@ -134,12 +138,15 @@ def thinker2talker_async_chunk( output_token_ids = _ensure_list(output_token_ids) talker_additional_info = { - "thinker_embeddings": pooling_output.get("0").detach().cpu(), "finished": torch.tensor(is_finished, dtype=torch.bool), } - - if not output_token_ids: + if output_token_ids: + talker_additional_info["override_keys"] = ["thinker_decode_embeddings", "thinker_output_token_ids"] + talker_additional_info["thinker_decode_embeddings"] = pooling_output.get("0").detach().cpu() + talker_additional_info["thinker_output_token_ids"] = output_token_ids + else: # When prefilling a chunked thinker, thinker_hidden_states needs to be updated. + talker_additional_info["thinker_prefill_embeddings"] = pooling_output.get("0").detach().cpu() talker_additional_info["thinker_hidden_states"] = pooling_output.get("24").detach().cpu() return talker_additional_info @@ -177,7 +184,7 @@ def thinker2talker( output = thinker_output.outputs[0] info = { - "thinker_embeddings": output.multimodal_output["0"].detach().to(device=device, dtype=torch.float), + "thinker_prefill_embeddings": output.multimodal_output["0"].detach().to(device=device, dtype=torch.float), "thinker_hidden_states": output.multimodal_output["24"].detach().to(device=device, dtype=torch.float), "thinker_sequences": ( thinker_output.prompt_token_ids + output.token_ids