Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def _maybe_update_prefix_cache(

self.omni_prefix_cache.update_omni_tensor_prefix_cache(
hidden_states=hidden_states,
multimodal_outputs=multimodal_outputs,
multimodal_outputs=flatten_payload(multimodal_outputs) if multimodal_outputs else multimodal_outputs,
num_tokens_unpadded=num_tokens_unpadded,
slot_mapping=self.input_batch.block_table[0].slot_mapping.cpu,
num_tokens_padded=num_tokens_padded,
Expand All @@ -190,7 +190,7 @@ def _maybe_get_combined_prefix_cache_tensors(
combined_multimodal_outputs = self.omni_prefix_cache.get_merged_multimodal_states(
query_start_loc=self.query_start_loc.cpu,
input_batch=self.input_batch,
multimodal_outputs=multimodal_outputs,
multimodal_outputs=flatten_payload(multimodal_outputs) if multimodal_outputs else multimodal_outputs,
num_scheduled_tokens=num_scheduled_tokens,
)
return combined_hidden_states, combined_multimodal_outputs
Expand Down Expand Up @@ -944,12 +944,23 @@ def propose_draft_token_ids(sampled_token_ids):
mm_payload: dict[str, object] = {}
if combined_multimodal_outputs or mm_cpu:
if combined_multimodal_outputs:
# Prefix cache enabled; all items have already been processed
# and split apart for each request as needed, and all tensors
# have already been detached to the CPU. Lists are kept as
# passthrough data for consistent behavior in postprocess.
# Recurse into nested dicts so list-valued sub-keys (e.g.
# embed.tts_bos = [tensor]) are unwrapped to bare tensors
# at the leaves; downstream flatten_payload then yields a
# wire-clean dict[str, torch.Tensor].
def _unwrap_lists(v):
if isinstance(v, list):
return v[idx] if idx < len(v) else v[0]
if isinstance(v, dict):
return {k: _unwrap_lists(sv) for k, sv in v.items()}
return v

for mm_key in combined_multimodal_outputs.keys():
value = combined_multimodal_outputs[mm_key][rid]
if isinstance(value, list):
mm_payload[mm_key] = value[idx] if idx < len(value) else value[0]
else:
mm_payload[mm_key] = value
mm_payload[mm_key] = _unwrap_lists(combined_multimodal_outputs[mm_key][rid])
else:
for mm_key, mm_val in mm_cpu.items():
mm_payload[mm_key] = to_payload_element(
Expand Down
Loading