Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions examples/offline_inference/hunyuan_image3/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,23 @@ def build_prompt(
task: str = "it2i_think",
sys_type: str | None = None,
custom_system_prompt: str | None = None,
kv_reuse: bool = False,
) -> str:
"""Build a HunyuanImage-3.0 prompt using pretrain template format."""
"""Build a HunyuanImage-3.0 prompt for the AR stage.

The HunyuanImage-3.0-Instruct model's generation_config specifies
``sequence_template="instruct"``, so the AR prefill MUST use the
Instruct template in both the normal (non-KV-reuse) and the KV-reuse
paths. Otherwise the AR KV cache token layout drifts from the model's
training distribution and output quality degrades.

Format (both kv_reuse=False and kv_reuse=True):
<|startoftext|>{sys}\\n\\nUser: [<img>]{user_prompt}\\n\\nAssistant: [trigger_tag]

The ``kv_reuse`` argument is retained for API compatibility but no
longer changes the prompt layout. The trigger tag is part of the AR
prefill; it must NOT be stored separately and re-prepended by the DiT.
"""
if task not in _TASK_PRESETS:
raise ValueError(f"Unknown task {task!r}. Choose from: {sorted(_TASK_PRESETS)}")

Expand All @@ -60,14 +75,18 @@ def build_prompt(

has_image_input = task.startswith("i2t") or task.startswith("it2i")

# Instruct template (matches the model's generation_config.sequence_template
# and the old prompt_utils.build_prompt layout).
parts = ["<|startoftext|>"]
if sys_text:
parts.append(sys_text)
parts.append("\n\nUser: ")
if has_image_input:
parts.append("<img>")
parts.append(user_prompt)
parts.append("\n\nAssistant: ")
if trigger_tag:
parts.append(trigger_tag)
parts.append(user_prompt)

return "".join(parts)

Expand Down Expand Up @@ -176,12 +195,29 @@ def main():

input_image = Image.open(args.image_path).convert("RGB")

# Detect KV-reuse stage configs. Note: the AR prompt layout is now the
# same Instruct template in both paths (see build_prompt); the flag is
# only used for informational purposes.
_kv_reuse_configs = {"hunyuan_image3_it2i_kv_reuse.yaml"}
kv_reuse = os.path.basename(stage_configs_path) in _kv_reuse_configs

# Format prompts
formatted_prompts: list[OmniPromptType] = []
for p in prompts:
formatted_text = build_prompt(p, task=task, sys_type=args.sys_type)

prompt_dict: dict = {"prompt": formatted_text}
formatted_text = build_prompt(p, task=task, sys_type=args.sys_type, kv_reuse=kv_reuse)
preset_sys_type, _, _trigger_tag = _TASK_PRESETS[task]
effective_sys_type = args.sys_type or preset_sys_type

prompt_dict: dict = {
"prompt": formatted_text,
"user_prompt": p,
"use_system_prompt": effective_sys_type,
}
# Note: under the Instruct template the trigger tag (<think>/<recaption>)
# is already included in the AR prefill, so we do NOT pass it separately
# to the DiT. The DiT reconstructs cot_text from the AR-generated text
# alone (which begins with the trigger tag's content, terminated by the
# matching closing tag like </think>).

if args.modality == "text2img":
prompt_dict["modalities"] = ["image"]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for ``ImageKVCacheManager.inject_prompt_kv_cache``.

This is the core interface the KV-reuse PR adds: it takes AR-produced
text KV tensors and pre-populates the DiT's ``image_kv_cache_map`` so
subsequent denoising steps can be run with ``first_step=False``
without recomputing the text KV.

The contract we guard here:

* For a pos-only branch, the cached prefix length equals
``pos_len + num_special_tokens`` and the first ``pos_len`` rows of
the cached tensors equal the input positive tensors verbatim.
* For a pos+neg branch with different lengths, both branches are
zero-padded to ``max(pos_len, neg_len)`` so they share a single
``L_text`` slot, and the returned length equals
``max_len + num_special_tokens``.
* The trailing ``num_special_tokens`` + EOI slots are zero-filled --
this is the layout the attention mask in ``_forward_with_kv_reuse``
relies on (those positions are masked out of the softmax).
"""

import pytest
import torch

from vllm_omni.diffusion.models.hunyuan_image3.hunyuan_image3_transformer import (
ImageKVCacheManager,
)

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]


def _mgr() -> ImageKVCacheManager:
    """Build an ``ImageKVCacheManager`` without running its ``__init__``.

    The real constructor looks up the SP world-size, which only exists
    inside a distributed worker. ``inject_prompt_kv_cache`` never reads
    SP state, so the instance is allocated directly and given only the
    attributes that method consumes.
    """
    manager = ImageKVCacheManager.__new__(ImageKVCacheManager)
    # Insertion order mirrors the order the attributes are stamped on.
    stub_attrs = {
        "num_heads": 4,
        "num_kv_heads": 2,
        "head_dim": 8,
        "scaling": 1.0,
        "image_token_len": 16,
        "image_kv_cache_map": None,  # nothing cached until injection
    }
    for attr_name, attr_value in stub_attrs.items():
        setattr(manager, attr_name, attr_value)
    return manager


def _rand_kv(length: int, kv_heads: int = 2, head_dim: int = 8):
k = torch.randn(length, kv_heads, head_dim)
v = torch.randn(length, kv_heads, head_dim)
return k, v


def test_inject_pos_only_length_and_layout():
    """Positive-only injection: check returned length, shape, prefix and zero tail."""
    text_len, n_special, n_eoi = 17, 3, 1
    manager = _mgr()
    src_k, src_v = _rand_kv(length=text_len)

    prefix_len = manager.inject_prompt_kv_cache(src_k, src_v, num_special_tokens=n_special)

    # The returned prefix counts text + special tokens but not the eoi slot.
    assert prefix_len == text_len + n_special

    stored_k, stored_v = manager.image_kv_cache_map
    stored = (stored_k, stored_v)

    # Whole cached layout is text + special + eoi = 17 + 3 + 1 = 21 rows.
    total_rows = text_len + n_special + n_eoi
    for tensor in stored:
        assert tensor.shape == (total_rows, 2, 8)

    # The leading text rows are copied through unchanged.
    for tensor, source in zip(stored, (src_k, src_v)):
        assert torch.equal(tensor[:text_len], source)

    # The 3 special slots plus the eoi slot stay zero; the attention mask
    # in _forward_with_kv_reuse excludes them from the softmax.
    tail_rows = n_special + n_eoi
    for tensor in stored:
        assert torch.equal(tensor[text_len:], torch.zeros(tail_rows, 2, 8))


def test_inject_pos_and_neg_same_length():
    """Equal-length pos/neg injection: each branch lands in its own slot of the K cache."""
    text_len, n_special = 10, 3
    # One branch spans text + special + eoi rows: 10 + 3 + 1 = 14.
    branch_rows = text_len + n_special + 1
    manager = _mgr()
    pos_k, pos_v = _rand_kv(length=text_len)
    neg_k, neg_v = _rand_kv(length=text_len)

    prefix_len = manager.inject_prompt_kv_cache(
        pos_k, pos_v, neg_k, neg_v, num_special_tokens=n_special
    )

    assert prefix_len == text_len + n_special
    stored_k, _ = manager.image_kv_cache_map
    # Two branches back to back: 2 * 14 = 28 rows.
    assert stored_k.shape == (2 * branch_rows, 2, 8)
    # Positive text sits in rows [0:10]; negative text starts right after
    # the first branch, i.e. rows [14:24].
    assert torch.equal(stored_k[:text_len], pos_k)
    assert torch.equal(stored_k[branch_rows:branch_rows + text_len], neg_k)


def test_inject_pos_and_neg_mismatched_length_pads_shorter_branch():
    """Shorter negative branch must be zero-padded to the positive length.

    This is the path that guards against the ``L_pos=6833, L_neg=1``
    degeneracy: whichever branch is shorter must be zero-padded up to
    ``L_text = max(L_pos, L_neg)`` and the returned length must reflect
    the padded max. Both the key and the value caches are checked,
    including the zero padding of the shorter branch.
    """
    mgr = _mgr()
    pos_k, pos_v = _rand_kv(length=12)
    neg_k, neg_v = _rand_kv(length=7)  # shorter

    cached_len = mgr.inject_prompt_kv_cache(pos_k, pos_v, neg_k, neg_v, num_special_tokens=3)

    assert cached_len == 12 + 3  # max_len + special

    cached_k, cached_v = mgr.image_kv_cache_map
    # Layout: [pos(12) | special(3) | eoi(1) | neg_padded(12) | special(3) | eoi(1)] = 32
    assert cached_k.shape == (32, 2, 8)
    # Positive branch preserved.
    assert torch.equal(cached_k[:12], pos_k)
    # Negative branch: first neg_len rows are the original neg_k, the
    # remaining rows up to L_text=12 must be zero padding.
    assert torch.equal(cached_k[16:16 + 7], neg_k)
    assert torch.equal(cached_k[16 + 7:16 + 12], torch.zeros(5, 2, 8))
    # Values follow the same layout, including the zero padding of the
    # shorter negative branch.
    assert torch.equal(cached_v[:12], pos_v)
    assert torch.equal(cached_v[16:16 + 7], neg_v)
    assert torch.equal(cached_v[16 + 7:16 + 12], torch.zeros(5, 2, 8))


def test_inject_neg_longer_than_pos_also_pads():
    """Mirror case: a longer negative prompt forces padding of the positive branch."""
    pos_len, neg_len, n_special = 5, 9, 3
    padded_len = max(pos_len, neg_len)
    # Each branch spans padded text + special + eoi rows: 9 + 3 + 1 = 13.
    branch_rows = padded_len + n_special + 1
    manager = _mgr()
    pos_k, pos_v = _rand_kv(length=pos_len)
    neg_k, neg_v = _rand_kv(length=neg_len)

    prefix_len = manager.inject_prompt_kv_cache(
        pos_k, pos_v, neg_k, neg_v, num_special_tokens=n_special
    )

    assert prefix_len == padded_len + n_special
    stored_k, _ = manager.image_kv_cache_map
    # 2 * 13 = 26 rows in total.
    assert stored_k.shape == (2 * branch_rows, 2, 8)
    assert torch.equal(stored_k[:pos_len], pos_k)
    # Rows between the positive text and the padded length are zero fill.
    pad_rows = padded_len - pos_len
    assert torch.equal(stored_k[pos_len:padded_len], torch.zeros(pad_rows, 2, 8))
    assert torch.equal(stored_k[branch_rows:branch_rows + neg_len], neg_k)


def test_inject_custom_num_special_tokens():
    """A non-default ``num_special_tokens`` is reflected in the length and shape."""
    text_len, n_special = 4, 5
    manager = _mgr()
    src_k, src_v = _rand_kv(length=text_len)

    prefix_len = manager.inject_prompt_kv_cache(src_k, src_v, num_special_tokens=n_special)

    assert prefix_len == text_len + n_special
    stored_k, _ = manager.image_kv_cache_map
    # One extra row beyond text + special accounts for the eoi slot.
    expected_rows = text_len + n_special + 1
    assert stored_k.shape == (expected_rows, 2, 8)


def test_inject_preserves_dtype_and_device():
    """Cached K and V must keep the dtype and device of the injected tensors.

    Inputs are float16 so a silent upcast to the default float32 would
    be detected. Both cached tensors are checked for dtype and device.
    """
    mgr = _mgr()
    pos_k = torch.randn(6, 2, 8, dtype=torch.float16)
    pos_v = torch.randn(6, 2, 8, dtype=torch.float16)

    mgr.inject_prompt_kv_cache(pos_k, pos_v, num_special_tokens=3)

    cached_k, cached_v = mgr.image_kv_cache_map
    assert cached_k.dtype == torch.float16
    assert cached_v.dtype == torch.float16
    assert cached_k.device == pos_k.device
    # The value cache gets the same device check as the key cache.
    assert cached_v.device == pos_v.device
Loading
Loading