diff --git a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py index 8cff1849aa5..84ac562cc3c 100644 --- a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py @@ -167,7 +167,7 @@ def _maybe_update_prefix_cache( self.omni_prefix_cache.update_omni_tensor_prefix_cache( hidden_states=hidden_states, - multimodal_outputs=multimodal_outputs, + multimodal_outputs=flatten_payload(multimodal_outputs) if multimodal_outputs else multimodal_outputs, num_tokens_unpadded=num_tokens_unpadded, slot_mapping=self.input_batch.block_table[0].slot_mapping.cpu, num_tokens_padded=num_tokens_padded, @@ -190,7 +190,7 @@ def _maybe_get_combined_prefix_cache_tensors( combined_multimodal_outputs = self.omni_prefix_cache.get_merged_multimodal_states( query_start_loc=self.query_start_loc.cpu, input_batch=self.input_batch, - multimodal_outputs=multimodal_outputs, + multimodal_outputs=flatten_payload(multimodal_outputs) if multimodal_outputs else multimodal_outputs, num_scheduled_tokens=num_scheduled_tokens, ) return combined_hidden_states, combined_multimodal_outputs @@ -944,12 +944,23 @@ def propose_draft_token_ids(sampled_token_ids): mm_payload: dict[str, object] = {} if combined_multimodal_outputs or mm_cpu: if combined_multimodal_outputs: + # Prefix cache enabled; all items have already been processed + # and split apart for each request as needed, and all tensors + # have already been detached to the CPU. Lists are kept as + # passthrough data for consistent behavior in postprocess. + # Recurse into nested dicts so list-valued sub-keys (e.g. + # embed.tts_bos = [tensor]) are unwrapped to bare tensors + # at the leaves; downstream flatten_payload then yields a + # wire-clean dict[str, torch.Tensor]. + def _unwrap_lists(v): + if isinstance(v, list): + return v[idx] if idx < len(v) else v[0] + if isinstance(v, dict): + return {k: _unwrap_lists(sv) for k, sv in v.items()} + return v + for mm_key in combined_multimodal_outputs.keys(): - value = combined_multimodal_outputs[mm_key][rid] - if isinstance(value, list): - mm_payload[mm_key] = value[idx] if idx < len(value) else value[0] - else: - mm_payload[mm_key] = value + mm_payload[mm_key] = _unwrap_lists(combined_multimodal_outputs[mm_key][rid]) else: for mm_key, mm_val in mm_cpu.items(): mm_payload[mm_key] = to_payload_element(