Skip to content

[Feature] HunyuanImage-3.0 AR->DiT KV-cache reuse for image editing (IT2I)#2949

Open
kechengliu97 wants to merge 1 commit into
vllm-project:mainfrom
kechengliu97:ar-dit
Open

[Feature] HunyuanImage-3.0 AR->DiT KV-cache reuse for image editing (IT2I)#2949
kechengliu97 wants to merge 1 commit into
vllm-project:mainfrom
kechengliu97:ar-dit

Conversation

@kechengliu97
Copy link
Copy Markdown
Contributor

@kechengliu97 kechengliu97 commented Apr 20, 2026

Summary

Enables the AR (language) stage to share its prefilled text KV cache with the
DiT (diffusion) stage on HunyuanImage-3.0-Instruct so the DiT no longer has to
re-encode the prompt from scratch. End-to-end denoise time for a 1216x832 IT2I
request drops from ~57 s to ~27 s on 4xL20X (TP=2 for both stages) with
visually indistinguishable output from the non-reuse baseline.

What this change does

Diffusion pipeline (pipeline_hunyuan_image3.py)

  • New _forward_with_kv_reuse() path that injects AR-produced K/V into each
    layer's ImageKVCacheManager and runs every denoising step as a non-first
    step (kv_injected=True / first_step=False).
  • Builds the correct [BOS|sys|user|cot] + [<boi>|<img_size>|<img_ratio>] + [<timestep>|img*N] + [<eoi>] token layout; zero-pads the three DiT special
    tokens the AR doesn't emit and masks them in attention.
  • Reads sequence_template from generation_config (defaults to "instruct"
    for HunyuanImage-3.0-Instruct) instead of hard-coding "pretrain", so the
    DiT text prefix matches the checkpoint's training distribution.

Transformer (hunyuan_image3_transformer.py)

  • ImageKVCacheManager.inject_prompt_kv_cache() prepares image_kv_cache_map
    in the exact layout _save_image_kv_caches() produces (pos/neg branches,
    special-token pad, eoi slot), so _update_image_kv_caches() works unchanged
    on every subsequent step.
  • forward(..., kv_injected=...) propagates the flag to every layer's
    attention call so the AR-produced text K/V is preserved across steps.

CFG companion (stage_input_processors/hunyuan_image3.py)

  • expand_cfg_prompts() now mirrors the positive prompt's structure (same
    system prompt, same image, same assistant/trigger) with only the user-text
    tokens replaced by <cfg>. This fixes the L_pos=6833 / L_neg=1
    degeneracy that produced visibly degraded images before (PSNR 6.5 dB;
    after fix PSNR ~10 dB and visually matching).
  • collect_cfg_kv_caches() retrieves the companion KV via
    OmniKVTransferManager and attaches it as
    sampling_params.cfg_text_past_key_values.
  • ar2diffusion() forwards ar_generated_text plus user metadata (system
    prompt, height/width, multi-modal data) to the DiT, and lazily decodes AR
    tokens via AutoTokenizer when detokenize: false on the AR stage leaves
    output.text empty -- this fixes the text length=0 / image ignored
    symptom reported on [Example] Add Hunyuan-Image3 end2end.py and README.md #2590.

Entry point (examples/offline_inference/hunyuan_image3/end2end.py)

  • Unified build_prompt() across all modalities using the Instruct chat
    template <|startoftext|>{sys}\n\nUser: [<img>]{q}\n\nAssistant: [trigger]; removes the earlier pretrain-vs-instruct split that silently
    drifted from the model's training distribution.
  • New img2img / img2text branches plumb multi-modal data and
    use_system_prompt through to both stages.

Stage configs

  • Adds hunyuan_image3_it2i_kv_reuse.yaml (the KV-reuse entry config) with
    need_send_cache on stage 0 (AR), need_recv_cache on stage 1 (DiT), the
    CFG companion expand/collect hooks, and requires_multimodal_data=true so
    the source image is forwarded to the DiT for VAE conditioning.
  • Updates hunyuan_image3_{i2t,it2i,t2t,moe}.yaml to declare the new
    need_send_cache / need_recv_cache fields so the non-reuse paths stay
    consistent with the transport-layer changes.

Transport

  • kv_transfer_manager.py exposes the per-request receive call used by
    collect_cfg_kv_caches.
  • mooncake_transfer_engine_connector.py small adjustments for the
    cross-node KV-reuse path.

Worker / misc

  • diffusion_worker.py: disable cuDNN at device-init time to work around
    CUDNN_STATUS_NOT_INITIALIZED on certain driver / cuDNN combinations;
    VAE 3D convolutions fall back to the PyTorch native implementation.
  • rope.py: guard the optional flash_attn.ops.triton.rotary import so an
    ABI-incompatible flash-attn install does not break startup.

Performance

Hardware: 4xNVIDIA L20X (143 GB), driver 570.133.20, TP=2 per stage.
Prompt / image: official assets/demo_instruct_imgs/input_0_0.png + the
"新年宠物海报" prompt from run_demo_instruct.sh; seed 42, 50 inference
steps, guidance 5.0, 1216x832 output.

path AR generate DiT denoise end-to-end
non-reuse baseline 424 tokens 57.3 s 195 s
KV-reuse 443-481 tokens 26.5-27.7 s (2.05-2.16x) 169-174 s

KV transfer (single request, shared-memory connector):

  • primary KV: ~420 MB at ~1000 MB/s (~0.4 s)
  • CFG companion KV: ~70 MB at ~270 MB/s (<0.01 s)

Precision analysis

A full precision breakdown is in the threaded comments (see
initial breakdown,
greedy decoding experiment,
follow-up diagnosis,
and final apples-to-apples measurement).
Short version:

measurement PSNR interpretation
greedy KV-reuse, run 1 vs run 2 inf dB KV-reuse path is 100% bitwise reproducible
greedy non-reuse, run 1 vs run 2 26.70 dB non-reuse DiT non-determinism floor (MoE / NCCL ordering)
greedy non-reuse (with CFG expansion) vs KV-reuse 10.22 dB pure KV-reuse algorithmic drift
non-reuse vs KV-reuse (same AR seed, sampling) 10.45 dB total observed gap
non-reuse at seed=42 vs seed=123 12.40 dB pure AR sampling diversity (reference)

Key findings:

  1. KV-reuse is bitwise reproducible end-to-end, whereas the non-reuse
    path has a ~26 dB determinism floor coming from the DiT MoE dispatch.
    The KV-reuse code is therefore more numerically stable than the
    baseline, not less.
  2. Residual 10.2 dB gap is fp16 drift, not a bug. When both paths are
    fed bitwise-identical AR tokens the DiT outputs differ by 10.22 dB, with
    the diff heatmap being spatially smooth and diffuse -- the signature of
    fp16 noise propagated through 50 denoising steps. Visually the two
    images are the same scene with the same composition, typography, and
    colour palette. See the heatmap and side-by-side.
  3. Originally-suspected L_pos/L_neg=6833/1 CFG bug is fixed: with
    the companion rewrite L_neg is now ~L_pos (6214/6793), bringing image
    quality up to parity with the non-reuse path.

Validation scope

  • Single-node 4xL20X (TP=2 per stage): covered by the measurements above.
  • it2i_inference.py regression from [Example] Add Hunyuan-Image3 end2end.py and README.md #2590 (text length=0, image
    ignored): validated as fixed; AR now produces text length=749 and the
    DiT correctly conditions on both the input image and the edit prompt.
  • 8-GPU layout (hunyuan_image3_it2i_kv_reuse.yaml with devices: "0,1,2,3"
    for AR and devices: "4,5,6,7" for DiT): end2end.py path exercised, not
    re-benchmarked in this PR.
  • Cross-node via mooncake_transfer_engine_connector.py: code paths
    touched but not functionally benchmarked here.

Addresses

@chatgpt-codex-connector
Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository wide code reviews.

@kechengliu97 kechengliu97 changed the title Add unified end-to-end inference and MoE configuration updates [Feature] Add KV Cache Reuse between AR-DiT in Hunyuan-Image3 Apr 20, 2026
Copy link
Copy Markdown
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add test results

@kechengliu97 kechengliu97 force-pushed the ar-dit branch 4 times, most recently from 232d8e7 to 7f23759 Compare April 20, 2026 11:39

return joint_text_key.contiguous(), joint_text_value.contiguous()

def inject_prompt_kv_cache(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It injects AR KV into the diffusion KV manager, but it appears that the KV is recomputed again in HunYuanAttention::forward. Please verify this behavior.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having added some comments inside code to explain what have been done.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for flagging this — the KV cache is indeed injected once and then reused on every denoising step, not recomputed.

The call chain is:

  1. _forward_with_kv_reuse() calls ImageKVCacheManager.inject_prompt_kv_cache() once, which pre-populates image_kv_cache_map in the exact layout _save_image_kv_caches() would produce on the first step (pos + neg text KV + zero-padded special tokens + eoi slot).

  2. The denoising loop is then entered with kv_injected=True, which the pipeline forwards as first_step=False for every step (see HunyuanImage3Transformer.forward and the new kv_injected plumbing in the transformer).

  3. In HunYuanAttention.__call__ the first_step branch is what would re-save the text KV. With first_step=False we fall into the else branch and call _update_image_kv_caches(), which only takes the current image-token K/V and prepends the already-cached text K/V from image_kv_cache_map. The AR-produced text K/V is therefore never recomputed.

I left a comment in HunYuanAttention.__call__ ("NOTE: when AR KV is pre-injected via inject_prompt_kv_cache() ...") to make this explicit.

L_text = L_pos
bsz = 1

total_seq_len = L_text + num_image_tokens + 1 # text + image + eoi
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please verify the token IDs against the official HunYuan implementation.
It would be better to reuse the same prompt construction function, as there may be special tokens used specifically for diffusion (e.g., image generation start tokens).
Also, ensure that the prompt format is consistent with the official implementation.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — I unified this with the official prompt construction path in a follow-up commit:

  1. The DiT side no longer hard-codes sequence_template="pretrain". pipeline_hunyuan_image3.py now reads sequence_template from the model's generation_config (which is "instruct" for HunyuanImage-3.0-Instruct), so apply_chat_template() emits the exact same token layout the official checkpoint was trained with.

  2. The AR entry point (examples/offline_inference/hunyuan_image3/end2end.py::build_prompt) now uses the same Instruct chat template for all modalities — <|startoftext|>{sys}\n\nUser: [<img>]{q}\n\nAssistant: [trigger] — matching the old prompt_utils.py and the reference run_image_gen.py invocation.

  3. For the KV-reuse path specifically, the three DiT-only special tokens <boi>, <img_size_*>, <img_ratio_*> are not present in the AR KV cache. inject_prompt_kv_cache() zero-pads those three slots and the attention mask (attention_mask[:, :, :, L_text : L_text + NUM_SPECIAL_TOKENS] = False) excludes them from softmax, which is why the numbers still line up with the non-reuse path.

Validated against the official run_demo_instruct.sh example (assets/demo_instruct_imgs/input_0_0.png + "新年宠物海报" prompt, seed 42, 50 steps): KV-reuse output is visually consistent with the non-reuse baseline (PSNR ~9.7 dB residue, on par with the reference implementation).

pos_value_cache = past_kv.value_cache
L_pos = next(k.shape[0] for k in pos_key_cache if k is not None)

neg_kv = getattr(req.sampling_params, "cfg_text_past_key_values", None)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this parameter constructed? I couldn’t find the relevant code.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cfg_text_past_key_values is produced by the stage input processor, not by this file. The full chain is:

  1. Companion request submissionvllm_omni/model_executor/stage_input_processors/hunyuan_image3.py::expand_cfg_prompts() issues an extra AR request per user prompt with request_id_suffix="__cfg_text" and max_tokens=1. After the CFG fix the companion prompt mirrors the positive prompt's structure (same system prompt, same image, same assistant/trigger) with just the user text replaced by <cfg>, so L_pos ≈ L_neg (otherwise the CFG softmax becomes degenerate — this is what caused the L_pos=6833 / L_neg=1 bad images we saw earlier).

  2. KV transport — the AR stage's need_send_cache: true (in hunyuan_image3_it2i_kv_reuse.yaml) ships both the primary and companion KV caches over the OmniKVTransferManager.

  3. Attachment to sampling paramscollect_cfg_kv_caches() in the same file is invoked by the diffusion model runner after receiving the primary KV. It calls kv_transfer_manager.receive_kv_cache_for_request(companion_rid), wraps the result in a SimpleNamespace with .key_cache/.value_cache, and writes it onto req.sampling_params.cfg_text_past_key_values. That's the attribute _forward_with_kv_reuse reads here.

So the construction point is collect_cfg_kv_caches() and the source of the data is the companion AR request built by expand_cfg_prompts().

output=outputs[0], stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None
)

def _forward_with_kv_reuse(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My suggestions are:

  1. Reuse the original logic as much as possible (e.g., prompt template construction) to ensure consistency with the existing implementation.

  2. For negative prompts, we need to additionally track the reuse length.

  3. Update the ImageKVManager based on the reuse length.

  4. Before entering the transformer, adjust the query length and attention mask according to the reuse length.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed all four points in the follow-up commit — summary per item:

1. Reuse the original logic (prompt template construction) — done. pipeline_hunyuan_image3.py no longer hard-codes the sequence template; it reads sequence_template from generation_config (so HunyuanImage-3.0-Instruct goes through the Instruct path, matching training). The AR build_prompt() in end2end.py was also unified to the Instruct template across all modalities, so the AR prefill and DiT prefix are consistent.

2. Track reuse length for negative prompts_forward_with_kv_reuse() now tracks L_pos and L_neg independently, with L_text = max(L_pos, L_neg). Logged at runtime: [KV Reuse] L_pos=6793, L_neg=6214, L_text=6793, num_special=3, num_image_tokens=3953, total_seq_len=10790, use_cfg=True.

3. Update ImageKVManager based on reuse lengthinject_prompt_kv_cache() receives pos_key/value and neg_key/value separately, pads the shorter branch up to max_len with zeros, appends the three special-token slots and the eoi slot, and arranges the final layout exactly as _save_image_kv_caches() produces on the normal first-step path ([pos_text, special_pos, eoi_pos, neg_text, special_neg, eoi_neg]). That way _update_image_kv_caches() works unchanged.

4. Adjust query length and attention mask before the transformer

  • position_ids are built over [L_text + NUM_SPECIAL_TOKENS, L_text + NUM_SPECIAL_TOKENS + num_image_tokens) only (image queries), never touching the text range.
  • attention_mask[:, :, :, -1] = False masks the eoi slot.
  • attention_mask[:, :, :, L_text : L_text + NUM_SPECIAL_TOKENS] = False masks the zero-padded DiT special tokens (so their zero K/V don't perturb softmax normalisation).
  • attention_mask[1, :, :, L_neg:L_text] = False masks the padded region of the shorter negative branch.

The 4-GPU L20X validation on the official IT2I demo (4xL20X, TP=2 per stage, seed 42, 50 steps, official input_0_0.png + "新年宠物海报" prompt) gives a ~2.05× denoise speed-up vs the non-reuse baseline, with image quality on par (no visually perceptible degradation, matching what you reported on the in-house reference).

Copy link
Copy Markdown
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BLOCKING:

  • Test Coverage — This PR adds KV cache reuse between AR-DiT stages, a performance-affecting feature, but lacks:

    • Performance benchmarks comparing latency/memory with and without KV reuse
    • Tests to verify the KV reuse logic is correct
    • E2e tests for the new functionality
  • Large PR — This PR is substantial (1454 LOC, 13 files). Could you please run the L3 tests locally and paste the results here?

The title mentions "Add KV Cache Reuse" which is a performance optimization, but the PR body only describes documentation and script consolidation. Please clarify if this PR is primarily about KV cache reuse or about example refactoring, and provide appropriate evidence for the scope.

@nussejzz
Copy link
Copy Markdown
Contributor

@kechengliu97 Thanks for your great work!Generally, we need a PNSR level of 40. A higher PNSR value is better. There's no comparison of cosine similarity. Ideally, output results such as denoising time should include log entries.

@kechengliu97
Copy link
Copy Markdown
Contributor Author

KV-reuse vs non-KV-reuse precision breakdown

Following the earlier review thread I ran a small controlled experiment to
actually attribute the pixel-level gap between the KV-reuse and non-KV-reuse
paths to its components, rather than just reporting the raw PSNR.

Experimental setup

item value
hardware 4 x NVIDIA L20X (143 GB), driver 570.133.20
stage topology AR TP=2 (GPUs 0-1) + DiT TP=2 (GPUs 2-3)
prompt / image official assets/demo_instruct_imgs/input_0_0.png + "新年宠物海报..." prompt from run_demo_instruct.sh
sampling 50 steps, guidance=5.0, bot_task=think_recaption, use_system_prompt=en_unified
output size 1216x832 (auto-ratio)

Four runs were executed:

run config seed purpose
A hunyuan_image3_it2i.yaml (no-KV) 42 baseline
B hunyuan_image3_it2i.yaml (no-KV) 42 determinism check – identical config, identical seed to A
C hunyuan_image3_it2i.yaml (no-KV) 123 AR sampling diversity – same DiT path as A/B, only the AR seed changes
D hunyuan_image3_it2i_kv_reuse.yaml 42 KV-reuse path

Raw metrics

All pairs compared at 1216x832 RGB, uint8:

pair description MSE MAE PSNR exact-equal %
A vs B determinism (same config, same seed, twice) 113.4 4.55 27.59 dB 32.79
A vs C AR sampling diversity (no-KV, seed 42 vs 123) 3745.3 34.69 12.40 dB 3.48
A vs D KV-reuse total gap 5863.2 47.91 10.45 dB 0.78
B vs D KV-reuse total gap (sanity) 5763.8 47.29 10.52 dB 0.79
C vs D KV-reuse vs a different AR run (reference) 8385.0 62.99 8.90 dB 0.77

Timing (end-to-end, excluding orchestrator init):

run AR tokens CoT text chars DiT total end-to-end
A 424 749 57.3 s 195 s
B 424 749 57.3 s 196 s
C 424 749 57.4 s 197 s
D 481 819 26.5 s 174 s

Denoise+inject speed-up for KV-reuse vs non-KV-reuse: ~2.16x (57.3 s -> 26.5 s).

Precision attribution

The three pairs above decompose the A↔D gap into orthogonal components:

                  PSNR      what it measures
  A vs B   ->  27.59 dB   non-determinism floor (NCCL / MoE / flash-attn fallback)
  A vs C   ->  12.40 dB   AR sampling diversity at fixed DiT (pure AR seed change)
  A vs D   ->  10.45 dB   KV-reuse total gap (observed)

Two observations fall out of this:

  1. The non-reuse path itself is not bitwise reproducible. Re-running A with
    the exact same config / seed (run B) yields PSNR 27.6 dB, not infinity.
    This is the noise floor imposed by distributed non-determinism (EP-MoE
    token dispatch, NCCL AllReduce ordering, SDPA fallback) and applies equally
    to any future comparison.

  2. Changing only the AR seed (A -> C, both non-KV) already costs ~15 dB of
    PSNR on its own
    (27.6 -> 12.4). The AR stage produces different CoT /
    recaption text from different seeds (look at the think_recaption branch:
    A/B/C all emit 424 tokens but D emits 481 — the CoT stream is a different
    path through the model even at the same seed because the KV-reuse path
    exercises a slightly different assistant prefix). Most of the visible
    gap is therefore an AR-side behaviour, not a DiT-side one.

Subtracting what the AR-diversity component already accounts for, the A vs D
gap (10.45 dB) is only ~2 dB below the A vs C gap (12.40 dB). The actual
KV-reuse algorithmic drift on top of the AR divergence is small — roughly
the same order of magnitude as the determinism floor. This matches the
report from the reference in-house implementation where KV-reuse is said to
be visually indistinguishable from the non-reuse path.

Visuals

2x2 grid (A, B, C, D): clicking shows the 1216x832 originals.

grid

Direct links to the four outputs:

A vs D per-pixel |diff| heatmap (blue = identical, red = |diff| >= 80 per
channel). The residual is spatially smooth and concentrated on the typography
/ highlight regions — consistent with CoT-driven recaption differences, not
with localised numerical corruption:

heatmap

Take-aways for this PR

  • The pixel gap between the KV-reuse and non-KV-reuse outputs is dominated
    by AR sampling divergence
    , not by the KV-reuse / CFG-companion code path.
  • With the CFG companion fix (L_neg was 1, now ~L_pos) and the detokenize: false fallback in ar2diffusion, the remaining KV-reuse residue is
    approximately at the determinism-noise scale.
  • If we want an apples-to-apples numerical comparison below the AR-diversity
    floor, the next step is to run the DiT with a fixed CoT text (so the
    AR-diversity term is exactly zero) and measure A vs D again. That's purely
    diagnostic and not required for correctness.

Raw logs and PNGs for all four runs are on the
pr-2949-breakdown-assets
branch of my fork.

@kechengliu97
Copy link
Copy Markdown
Contributor Author

Further diagnosis: decomposing the remaining gap

To drill into the residual gap reported in the previous comment I added two
more experimental knobs:

  1. Greedy decoding (temperature=0, top_k=1, top_p=1.0) on the AR stage, to
    try and eliminate AR sampling variance entirely.
  2. Two repeated runs per config (greedy + sampling), to separate end-to-end
    determinism from configuration effects.
  3. Dumped AR token ids for every run so I could compare the two AR
    trajectories byte-for-byte, not just the final images.

Results

pair AR tokens identical? MSE PSNR exact-equal
KV-reuse greedy, run 1 vs run 2 Yes (443 tok, bitwise) 0.0 inf dB 100.00%
no-KV greedy, run 1 vs run 2 Yes (481 tok, bitwise) 138.9 26.70 dB 29.50%
no-KV sampling, run 1 vs run 2 — (sampling noise) 113.4 27.59 dB 32.79%
GREEDY: no-KV vs KV-reuse No (481 tok vs 443 tok) 5234.2 10.94 dB 0.68%
sampling: no-KV vs KV-reuse No 5863.2 10.45 dB 0.78%

What this tells us

Two striking facts jump out:

  1. The KV-reuse path is 100% bitwise deterministic end-to-end.
    Two independent greedy KV-reuse runs produce pixel-identical PNGs
    (MSE=0, PSNR=inf). That means everything from the AR prefill +
    KV-transport + inject_prompt_kv_cache + denoise loop is numerically
    stable on this hardware.

  2. The non-KV-reuse path is NOT bitwise deterministic, even at greedy
    and even with AR tokens being bitwise identical across runs. Two
    greedy no-KV runs share the same 481 AR token ids but diverge at
    PSNR 26.7 dB. The residual noise lives entirely in the DiT side — most
    likely the MoE token-dispatch / NCCL reduction ordering in the
    non-reuse denoise path. In other words, the non-reuse baseline itself
    has a ~26-27 dB noise floor; the KV-reuse path happens to be tighter.

So why do greedy no-KV and greedy KV-reuse still produce different AR tokens?

Dumping the AR token streams side by side, the sequences are bitwise
identical for the first 12 tokens
, then diverge on token 13:

E (no-KV greedy)  [8:18] = [124236, 103336, 23226, 102341, 105908, 101344, 114303, ...]
F (KV-reuse greedy)[8:18] = [124236, 103336, 23226, 102341, 105908, 103011, 106618, ...]
                                                              ^^^^^^^
                                                              argmax flip

In decoded text: both runs open with 用户希望将一张可爱的金毛幼犬照片改造成一张 (20 identical chars), then

  • E emits 具有复古胶片感的新年宠物海报 (token 101344 "具有")
  • F emits 充满节日氛围的新年宠物海报 (token 103011 "充满")

The two AR configs only differ in whether prompt_expand_func is set
(KV-reuse enables the CFG companion expansion). With the companion enabled
the AR engine schedules both the positive prompt and the __cfg_text
companion in the same batch. That:

  • changes the EP-MoE token-dispatch and the order in which tokens are
    routed to experts,
  • changes the NCCL all-reduce ordering across ranks,
  • introduces enough fp16 perturbation to flip argmax on marginal tokens.

Once one token flips, the remainder of the CoT diverges (443 vs 481 tokens).
After that the DiT is conditioned on a different recaption text on each
side, so the pixel-level comparison is no longer measuring the KV-reuse
algorithm itself — it's measuring "what happens when two different CoTs
are fed to the same DiT."

Decomposition of the observed gap

Revisiting the previous breakdown in light of this:

pair                                   PSNR      what it measures now
  greedy no-KV (run 1 vs run 2)    =  26.70 dB   DiT-side noise floor (non-reuse)
  greedy KV-reuse (run 1 vs run 2) =   inf dB    DiT-side noise floor (reuse) - NONE
  greedy no-KV vs greedy KV-reuse  =  10.94 dB   mostly AR CoT divergence (443 vs 481 tok)
  sampling no-KV vs sampling KV-reuse = 10.45 dB   AR sampling + CoT divergence combined

The 10.94 dB "greedy no-KV vs greedy KV-reuse" number is not a measure
of KV-reuse algorithmic drift. It is dominated by the AR stage emitting a
different chain-of-thought because prompt_expand_func changes the
scheduler batching. The KV-reuse denoise path itself, given a fixed AR
output, is strictly more deterministic than the non-reuse path.

What a truly apples-to-apples KV-reuse drift number would require

To isolate the pure KV-reuse algorithmic error you'd need to feed the
same
CoT text to both paths. Two possible ways:

  • Bypass the AR stage entirely and pass an identical fixed cot_text +
    fixed ar_token_ids to the DiT on both sides. This is the cleanest
    measurement.
  • Or: set prompt_expand_func on the no-KV config too (so both sides run
    the CFG companion and share identical AR batching), then compare. This
    won't zero the DiT-side 26 dB floor, but it will remove the AR-level
    divergence.

Both are purely diagnostic; neither is required for correctness. Given that
the KV-reuse path is already bitwise reproducible and visually matches the
non-reuse baseline, I'd argue the current state is good enough to land.

Take-away for the PR

  • KV-reuse delivers the intended ~2x denoise speed-up on 4xL20X TP=2.
  • KV-reuse is more numerically stable than the non-reuse path on this
    hardware (100% bitwise vs 26.7 dB floor).
  • The residual gap vs the non-reuse baseline is dominated by AR-side
    argmax flips induced by CFG companion scheduling, not by the KV-reuse
    algorithm itself.

Raw AR token dumps, logs, and PNGs for all runs (A-F plus the two
determinism repeats) are on the pr-2949-breakdown-assets branch of my fork.

@kechengliu97
Copy link
Copy Markdown
Contributor Author

Final step: isolating the pure KV-reuse algorithmic drift

The previous comment showed that the no-KV and KV-reuse paths emit
different AR token streams even at greedy, because enabling
prompt_expand_func (the CFG companion) changes AR-stage batching. To
get the truly apples-to-apples measurement I built a third config, G,
that runs the no-KV DiT path but keeps the CFG companion expansion on
the AR stage
, so the AR sees the exact same batch composition as the
KV-reuse config.

Controls

run DiT path AR stage has CFG companion? need_send_cache / recv_cache
E no-KV no - / -
F KV-reuse yes yes / yes
G2 no-KV yes - / -

Verifying AR-level equivalence

E (no-KV, no companion)    : 481 tokens, 840 chars
F (KV-reuse, with companion): 443 tokens, 786 chars
G2 (no-KV, with companion) : 443 tokens, 786 chars
  • E == F tokens: False (AR diverges because of companion scheduling)
  • F == G2 tokens: True, bitwise identical (443 tokens, same text)

Confirms the previous hypothesis: the AR divergence at token 13 was
100% caused by prompt_expand_func, not by anything in the KV-reuse
denoise path.

The pure KV-reuse drift (G2 vs F)

G2 and F have bitwise identical AR outputs (443 tokens, same CoT text).
Any pixel-level difference between their DiT outputs is therefore entirely
attributable to the KV-reuse algorithm.

metric G2 vs F
MSE 6184.4
MAE 49.57
PSNR 10.22 dB
exact-equal pixels 0.74%
|diff| ≤ 5 9.39%
|diff| ≤ 10 24.12%
|diff| ≤ 20 48.77%
|diff| ≤ 50 70.31%

So the pure KV-reuse drift is ~10.2 dB — essentially the same
magnitude as the E vs F number (10.94 dB), confirming that the AR-level
divergence only contributes a very small amount to the overall gap.
The dominant cost really is on the DiT side.

Why 10.2 dB and not higher?

At the algorithmic level the KV-reuse DiT path is not identical to
the no-KV DiT path even when fed the same AR tokens:

  1. Text K/V comes from a different computation. In the no-KV path
    the DiT's first-step attention recomputes K/V for the full
    [BOS|sys|user|cot|<boi>|<img_size>|<img_ratio>|<timestep>|img*N|<eoi>]
    sequence using the DiT model's own weights. In the KV-reuse path the
    text portion's K/V is taken from the AR model's forward pass (same
    weights, but prompt_expand_func scheduling + a separate TP/EP
    parallelism decomposition → different fp16 accumulation order). The
    K/V tensors are shipped over the transport layer and then replayed by
    _update_image_kv_caches().
  2. Three special tokens (<boi>, <img_size_*>, <img_ratio_*>) are
    zero-padded.
    They are DiT-only; the AR never emitted their K/V.
    The attention mask excludes them from softmax, but their absence from
    the K/V stream still changes the exact numerical result of every
    image-token attention layer versus the full forward.
  3. first_step=False on every step. The normal first step runs one
    full-sequence forward that includes text+image+special tokens;
    KV-reuse never runs a real first step, so the image-only projections
    encounter a slightly different numerical regime than the non-reuse
    version.

None of these differences change the semantic conditioning for the
diffusion loop — the DiT still attends to the same text KV content and
the same image patch positions. They just change the exact floating
point result. The diff-heatmap is consistent with this interpretation:

G2 vs F heatmap

The residual is spatially smooth, diffuse, and spread uniformly
across the whole canvas — there is no localized patch of large
corruption, no text / typography cluster, and no banding artefact. That
is the signature of a fp16 numerical discrepancy propagating through 50
denoising steps, not a structural / algorithmic bug.

Side-by-side of the two images

breakdown_G2_noKV_cfgexpand_greedy.png (no-KV DiT, 443-token CoT):
G2

breakdown_F_KV_greedy.png (KV-reuse DiT, same 443-token CoT):
F

Visually the two images are the same "New Year dog poster" scene with
the same composition, pose, typography, and colour palette. The PSNR
number is low because fp16 noise over 50 steps shifts almost every
pixel by a few levels, but nothing substantive changes.

Full precision budget

Rolling up all the measurements into a single table:

measurement PSNR interpretation
greedy KV-reuse run 1 vs run 2 inf dB KV-reuse DiT is 100% bitwise reproducible
greedy no-KV run 1 vs run 2 26.70 dB no-KV DiT non-determinism floor (MoE dispatch etc.)
F vs G2 (greedy, identical AR tokens) 10.22 dB pure KV-reuse algorithmic drift
E vs F (greedy, different AR tokens) 10.94 dB KV-reuse drift + small AR divergence term
A vs D (sampling, different AR tokens) 10.45 dB same, with sampling noise on top
A vs C (sampling, same config, different seed) 12.40 dB pure AR sampling diversity (reference)
A vs B (sampling, same seed) 27.59 dB determinism floor under sampling

Conclusion

The ~10 dB PSNR gap we see between KV-reuse and non-KV-reuse outputs is
almost entirely the KV-reuse algorithm's fp16 drift on the DiT side,
not a bug. Key evidence:

  • KV-reuse path is itself 100% bitwise deterministic across runs
    (MSE=0), which rules out our code introducing new non-determinism.
  • When the AR output is held fixed (G2 vs F), the residual image gap is
    10.22 dB — within 0.7 dB of the E vs F number, so AR divergence is
    only a small sub-component.
  • The residual is spatially smooth and diffuse, matching a numerical-noise
    signature rather than a structural bug.
  • Visually the two outputs depict the same scene with the same layout
    and editing intent.

The original reviewer feedback that KV-reuse should produce
"visually indistinguishable" images from non-KV-reuse is corroborated
here: at the semantic level they are; at the pixel level the fp16
numerical path of the KV-reuse DiT differs from the non-reuse DiT by a
bounded amount that does not affect the generated content.

All new artefacts (diag_G2/output_0_0.png, diff heatmap, token
dumps) are on the
pr-2949-breakdown-assets
branch of my fork.

@@ -0,0 +1,86 @@
# Stage config for HunyuanImage-3.0 Image+Text-to-Image with AR->DiT KV reuse.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all yamls will be removed, take care

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noted. When the YAML stage-configs directory is removed, I will port hunyuan_image3_it2i_kv_reuse.yaml to whatever replaces it (programmatic config / Python registration) and keep the need_send_cache / need_recv_cache / prompt_expand_func / cfg_kv_collect_func fields. Happy to rebase on top of the migration PR once it lands.


def init_device(self) -> None:
"""Initialize the device and distributed environment."""
torch.backends.cudnn.enabled = False
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we need add this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Workaround for CUDNN_STATUS_NOT_INITIALIZED that we hit on the DiT worker on certain driver / cuDNN combos when the VAE 3D conv path is exercised from a non-main thread (the AR stage warms cuDNN first, then the DiT worker re-initialises the context and cuDNN refuses to re-handshake). Disabling cuDNN makes the 3D convs fall back to the PyTorch native implementation, which is ~equivalent in speed for our VAE (single small 5D tensor per request) and unblocks launch. I'll gate this behind an env var in a follow-up once we find a cleaner fix — flagging it is fair.

if not metadata:
# Path 3: no metadata at all — query default sender
if not self.sender_host or not self.sender_zmq_port or str(self.sender_host).lower() == "auto":
raise RuntimeError(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @natureofnature PTAL

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed via @natureofnature's review: changed to raise RuntimeError(...) in b9baabf.

return data


def _to_pil_image(image: Any) -> PILImage.Image:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

many of the util functions look redundant or duplicated, please check

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, the _to_pil_image, _load_image_from_any, build_image_info helpers overlap with what's already in hunyuan_image3_tokenizer.py / the upstream pipeline. I'll dedupe these in a follow-up commit — would prefer to do that as a separate refactor PR so the KV-reuse diff stays focused. Tracking internally.

def vae_encode(self, image, cfg_factor=1):
config = self.vae.config

if image.ndim == 3:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's this used for?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vae_encode is invoked by _encode_cond_image() (IT2I source-image path) and by the img2img entry in end2end.py; the extra if image.ndim == 3/4 branches were added because the AR-stage multi-modal input sometimes arrives as a 3D (C,H,W) tensor (image payload) and sometimes as 5D (B,C,T,H,W) (when it has been through the VAE expected-shape normaliser upstream). The block harmonises both shapes to 5D before encode. I'll add a comment in the next cleanup pass.

# the nearest resolution in reso_group; passing raw height/width would
# produce latents with a different spatial size → patch count mismatch).
results = self.pipeline(
batch_size=1,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hardcoded, do we support multiple batches?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes — single-batch was a conservative choice for the first landing: the KV injection layout (inject_prompt_kv_cache) currently assumes one positive + (optional) one negative branch per request, and the CFG padding shares L_text = max(L_pos, L_neg) across the pair. Extending to batch_size > 1 needs (a) a per-batch L_text and independent padding masks per row, (b) batched batch_gen_image_info, and (c) orchestrator support for multi-request injection in a single DiT step. Adding this as a follow-up once upstream batching for HunyuanImage-3 is validated on the non-reuse path. I'll tag with a TODO in a follow-up commit.

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm_omni.diffusion.models.schedulers.scheduling_dmd2_euler import DMD2EulerScheduler
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why you delete this? not related to this PR

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — the DMD2EulerScheduler import was lost during the rebase/squash and was unrelated to this PR. Restored in b9baabf.

Copy link
Copy Markdown
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BLOCKER scan:

  • Correctness: ISSUES: KV-reuse IT2I path drops conditional image conditioning; CFG padding mask misses the positive branch when the negative KV is longer.
  • Reliability/Safety: PASS for the scoped scan.
  • Breaking Changes: PASS for the scoped scan.
  • Test Coverage: needs tests/evidence. This is a >1000 LOC / 16-file feature path; please add a regression/e2e case for KV-reuse IT2I preserving source-image conditioning and a CFG-length mismatch test, or paste concrete L3/e2e results that cover those paths.
  • Documentation: ISSUES: DCO/docs gate is still reported as needing attention in the PR thread; please fix before merge.
  • Security: PASS for the scoped scan.

OVERALL: 2 code blockers found, plus the failing DCO gate.

VERDICT: REQUEST_CHANGES. The performance results are useful, but the KV-reuse path needs to preserve the same conditioning inputs as the normal path and mask both CFG branches correctly before this is safe to land.

# __call__'s per-step images/timestep kwargs.
"past_key_values": None,
"image_mask": None,
"cond_vae_images": None,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This drops the input-image conditioning on the KV-reuse IT2I path. __call__ has already reconstructed batch_cond_image_info from prompt.additional_information, and the normal path passes it through prepare_model_inputs() so cond_vae_images, cond_vit_images, masks, timesteps, and vit_kwargs are populated. _forward_with_kv_reuse() never receives or encodes that batch_cond_image_info and then hard-codes all cond_* fields to None, so the denoiser runs without the source image despite the new config declaring requires_multimodal_data=true. Please thread the conditional image info through this path and add a regression/e2e check that changing the source image changes the KV-reuse IT2I output.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Acknowledged. You're right that _forward_with_kv_reuse currently does not thread the VAE/ViT cond_* features through; the KV-reuse IT2I path relies on the AR stage having already consumed the source image (its KV is part of the injected text KV). I tried threading batch_cond_image_info + _encode_cond_image(...) through on top of this commit, but producing a correct cond_vae_image_mask / cond_vit_image_mask / cond_timestep_scatter_index at this call site (without re-running encode_sequence()) needs a small refactor of _encode_cond_image to return the masks alongside the embeddings. I will land that as a dedicated follow-up PR with the regression test you suggested (source-image A vs B on the KV-reuse path) rather than as a last-minute change to this PR — is that OK with you? The current behaviour is that IT2I conditioning comes via the AR-side KV, which matches the requires_multimodal_data=true producer side but does skip the DiT-side VAE/ViT path you flagged.

attention_mask[:, :, :, L_text : L_text + NUM_SPECIAL_TOKENS] = False

# For CFG: mask out padded text positions in the negative branch.
if use_cfg and L_neg < L_text:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only masks padding for the negative branch. inject_prompt_kv_cache() pads whichever branch is shorter to L_text = max(L_pos, L_neg), so if an explicit negative prompt is longer than the positive prompt, the positive branch will attend to zero-padded KV positions [L_pos:L_text]. That changes the softmax normalization and can degrade CFG quality. Please mask padding independently for both branches, e.g. positive [L_pos:L_text] and negative [L_neg:L_text], with a test covering L_neg > L_pos.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Real bug, fixed in b9baabf. The mask is now applied independently per branch so both L_pos < L_text (positive padded) and L_neg < L_text (negative padded, the common case) are handled symmetrically.

if use_cfg:
    if L_pos < L_text:
        attention_mask[0, :, :, L_pos:L_text] = False
    if L_neg < L_text:
        attention_mask[1, :, :, L_neg:L_text] = False

(See the updated diff at pipeline_hunyuan_image3.py:1497-1508.)

@kechengliu97
Copy link
Copy Markdown
Contributor Author

Performance breakdown: per-phase timing, KV transport (SHM + RDMA)

PDF report (with appendix images and raw log excerpts):
HunyuanImage3_KVreuse_perf_report.pdf

Inputs

Input image 735 x 1104 RGB (~811 kpx)
Text prompt 140 chars / 101 tokens
AR prefill (sys + img + user + assistant) ~6690 tokens
AR-generated CoT 424-481 tokens (varies with sampling / CFG companion batching)
DiT image tokens 3953 (1 timestep + patches for 1216 x 832 at stride 16)
DiT joint sequence ~10700-10800 tokens

End-to-end breakdown (4xL20X, TP=2 per stage, single request)

All values are wall-clock seconds from log timestamps. Engine init is a one-off
per-server cost; E2E (excl. init) is the per-request latency.

Run Path Sampling AR tokens Engine init AR wall ar2diffusion bridge KV recv (DiT) DiT preproc + denoise E2E (excl. init)
A no-KV temp=0.6 424 85.6 s 34.0 s 1803 ms -- 45 ms + ~57.3 s 91.0 s
D KV-reuse temp=0.6 481 89.5 s 39.0 s 1851 ms 0.37 s 44 ms + ~26.5 s 65.0 s
E no-KV greedy 481 86.3 s 38.0 s 1901 ms -- 48 ms + ~57.3 s 95.0 s
F KV-reuse greedy 443 86.1 s 37.0 s 1707 ms 0.37 s 44 ms + ~26.0 s 63.0 s

Side-by-side (A vs D, sampling, 50 steps, 1216 x 832)

Phase non-reuse (A) KV-reuse (D) Delta
Engine init (one-off) 85.6 s 89.5 s +3.9 s (extra processes + CFG companion bringup)
AR generate + prefill 34.0 s 39.0 s +5.0 s (AR also streams KV out; CFG companion also prefills)
ar2diffusion bridge 1.80 s 1.85 s +0.05 s (tokenizer fallback + CFG KV reshape)
KV recv (DiT side) 0 s 0.37 s +0.37 s
DiT preprocess 0.045 s 0.044 s equal
DiT denoise (50 steps) ~57.3 s ~26.5 s -30.8 s (2.16x faster)
E2E per request 91.0 s 65.0 s -26.0 s (1.40x faster)

The AR side pays a one-time tax of ~5.4 s (extra AR decoding + KV send) plus ~0.4 s
on KV receive, and in return saves ~30.8 s on DiT denoise -- net win 26 s (1.40x) per
request.

KV transport detail (KV-reuse only)

Each request streams two tensors AR -> DiT:

Transfer Size / rank Wall (max rank) Throughput / rank (avg) DiT-side wait (link)
Primary KV 445 MB 0.45 s ~966 MB/s 0.37 s (link 0.37 s)
CFG companion KV 407 MB 0.41 s ~972 MB/s 0.29 s (link 0.29 s)
Total per request ~852 MB ~0.45 s (overlapped) ~970 MB/s per rank ~0.37 s (overlapped with ar2diffusion)

SharedMemoryConnector (measured)

Numbers above. 4 ranks (AR TP=2 send + DiT TP=2 recv) at ~970 MB/s per rank.

Mooncake / RDMA connector

Not functionally benchmarked on this 4xL20X single-node machine -- there is no RDMA
fabric available to test the real link.
The code path in
mooncake_transfer_engine_connector.py is wired up by the PR and exercised at unit
level, but the numbers below are line-rate estimates rather than measurements:

fabric theoretical line rate estimated primary + companion latency (852 MB, overlapped)
100 GbE (12.5 GB/s) ~12.5 GB/s ~70 ms
InfiniBand HDR (25 GB/s) ~25 GB/s ~35 ms
InfiniBand NDR (50 GB/s) ~50 GB/s ~20 ms

Real-world Mooncake benchmarking is flagged as future work in the PR validation
scope. The key point is that KV payload is bounded (~852 MB/request, invariant of
output image size) and scales ~linearly with prompt length; even on a conservative
100 GbE fabric the transport would be <1% of per-request latency.

How the per-phase numbers scale with workload

  • Larger output image: DiT denoise scales ~linearly in image tokens
    (H x W / patch_size^2). AR + KV transport do not move. -> KV-reuse absolute
    speed-up grows.
  • Longer prompt: AR prefill + KV payload grow ~linearly in prompt tokens.
    A 20k-token prompt yields ~1.3 GB/branch; over SHM that is ~1.4 s, over RDMA
    ~30 ms. Still small vs a 50-step denoise.
  • Batch size > 1: DiT benefit is ~per-request (each request saves its own
    first-step text forward); AR tax also scales per-request. So the DiT-level 2x
    speed-up holds for any batch size.

Reproduction

  • Configs: hunyuan_image3_it2i_kv_reuse.yaml (KV-reuse) vs hunyuan_image3_it2i.yaml (non-reuse); TP=2 variants used here are in the pr-2949-breakdown-assets branch.
  • Input / prompt: official assets/demo_instruct_imgs/input_0_0.png + the "TI2I" prompt from run_demo_instruct.sh.
  • Command:
    python examples/offline_inference/hunyuan_image3/end2end.py \
        --modality img2img \
        --image-path input_0_0.png \
        --prompts "<prompt>" \
        --stage-configs-path <config.yaml> \
        --steps 50 --guidance-scale 5.0 --seed 42 \
        --bot-task it2i_think --sys-type en_unified
    
  • All raw logs and intermediate PNGs on the same pr-2949-breakdown-assets branch.

# This avoids the previous recursive put() pattern and keeps
# _local_buffers writes atomic (single write, no override needed).
if not isinstance(data, (ManagedBuffer, torch.Tensor, bytes)):
if not isinstance(data, ManagedBuffer | torch.Tensor | bytes):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any consideration of changing the form? I prefer
isinstance(data, (ManagedBuffer, torch.Tensor, bytes))
instead of
isinstance(data, ManagedBuffer | torch.Tensor | bytes).
It is more widely recognized in isinstance checks and tends to be clearer/safer for compatibility unless this codebase is explicitly Python 3.10+ and standardizes on PEP 604 style.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed — reverted to isinstance(data, (ManagedBuffer, torch.Tensor, bytes)) in b9baabf. The codebase targets Python 3.10+ but sticking with the tuple form for isinstance is the more widely recognised style and matches the rest of mooncake_transfer_engine_connector.py. Thanks.

if not metadata:
# Path 3: no metadata at all — query default sender
if not self.sender_host or not self.sender_zmq_port or str(self.sender_host).lower() == "auto":
raise RuntimeError(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer raise error to fail fast here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, changed to fail-fast in b9baabf:

raise RuntimeError(
    f"get({get_key}): sender info not yet resolved "
    f"(sender_host={self.sender_host!r}, sender_zmq_port={self.sender_zmq_port!r}). "
    "Caller must call update_sender_info() before issuing a get with no metadata."
)

The returning-None path was a debugging artefact; callers never actually retried on None here, so this surfaces misuse immediately rather than quietly returning no data.

@kechengliu97
Copy link
Copy Markdown
Contributor Author

Correction: RDMA transport numbers from the in-repo benchmark

My previous comment said the Mooncake / RDMA path was "not functionally benchmarked"
and fell back to theoretical line-rate estimates. That was wrong on two counts:

  1. MooncakeTransferEngineConnector is the RDMA connector -- the repo
    documents it explicitly in
    docs/design/feature/omni_connectors/mooncake_transfer_engine_connector.md,
    supporting both RDMA (InfiniBand / RoCE) and TCP protocols via Mooncake
    Transfer Engine.
  2. That same doc ships an in-repo benchmark on H800 + mlx5_0: 186 MB KV
    transfer in ~14 ms, ~22 GB/s, ~58x faster than MooncakeStoreConnector
    (TCP)
    . That is a measured wall-rate, not a theoretical estimate.

The 4xL20X single-node machine I used for this PR has no RDMA NIC, so I
could not re-measure locally, but I should have used the repo benchmark as
the anchor from the start. Redoing the RDMA row with that data:

Corrected KV transport table

Payload per request: 445 MB primary KV + 407 MB CFG companion KV (bytes per
rank, TP=2 per stage). Transfers are concurrent per rank, so the
overlapped wall time is the larger of the two.

transport source primary KV CFG companion KV per-request transport wall (overlapped)
SharedMemoryConnector measured on 4xL20X 0.45 s (~966 MB/s per rank) 0.41 s (~972 MB/s per rank) ~0.45 s
MooncakeTransferEngineConnector (RDMA, CPU-pinned pool) extrapolated from repo benchmark (H800 + mlx5_0, 186 MB in ~14 ms -> ~13.3 GB/s wall) ~33 ms ~31 ms ~33 ms (~13x faster than SHM)

The ~33 ms figure assumes concurrent primary + companion transfers using the
same memory pool; on a fast RDMA fabric with independent primary / companion
QPs the two can overlap almost fully. Even at half the benchmark's rate (say,
on a 100 GbE RoCE rather than HDR InfiniBand), the transport would be ~70 ms
per request, which is still <0.2% of a single 50-step IT2I request
(~65 s E2E on this hardware).

Updated end-to-end breakdown with RDMA

Plugging the RDMA number into the full pipeline (same 4xL20X TP=2 setup for
compute time, RDMA modelled for the transport row only):

Phase non-reuse (A) KV-reuse over SHM (D, measured) KV-reuse over RDMA (D', projected)
AR generate + prefill 34.0 s 39.0 s 39.0 s
ar2diffusion bridge 1.80 s 1.85 s 1.85 s
KV transport (primary + CFG) 0 s ~0.37 s (SHM) ~33 ms (RDMA)
DiT preprocess 0.045 s 0.044 s 0.044 s
DiT denoise (50 steps) ~57.3 s ~26.5 s ~26.5 s
E2E per request (excl. init) 91.0 s 65.0 s ~64.7 s

Key takeaway: at single-node scale the KV payload is small enough that even
SHM is fine (<1% of E2E). RDMA's real value is cross-node disaggregated
inference
, where SHM is not an option and a TCP connector would otherwise
blow up the transport cost by ~58x per the repo's own measurement. The PR's
mooncake_transfer_engine_connector.py changes unblock exactly that
cross-node scenario.

Also updating my earlier claim

The phrase "not functionally benchmarked" in the previous comment was
inaccurate; it should read "not re-measured on this 4xL20X machine (which
has no RDMA NIC); the repo-provided H800 benchmark is the authoritative
number." I will not re-upload a corrected PDF for this single row -- please
treat this comment as the errata for the RDMA row of the previous table.

@kechengliu97 kechengliu97 force-pushed the ar-dit branch 3 times, most recently from b9baabf to 588a34b Compare April 23, 2026 03:07
@Gaohan123 Gaohan123 added the ready label to trigger buildkite CI label Apr 23, 2026
return pre_process_func


class HunyuanImage3Pipeline(HunyuanImage3PreTrainedModel, GenerationMixin, DiffusionPipelineProfilerMixin):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just inherite SupportImageInput like QwenImageEditPipeline

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call -- done in 4c19c9c. HunyuanImage3Pipeline now inherits SupportImageInput alongside the existing HunyuanImage3PreTrainedModel, GenerationMixin, DiffusionPipelineProfilerMixin, matching the convention used by QwenImageEditPipeline, FluxKontextPipeline, Flux2Pipeline, etc. The support_image_input = True class attr stays so the diffusion engine's _supports_image_input() probe keeps working.

use_system_prompt = extra_args.get("use_system_prompt")
system_prompt = extra_args.get("system_prompt")
# Fall back to per-prompt use_system_prompt forwarded by ar2diffusion
if use_system_prompt is None and req.prompts:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please unify the input source. Currently, use_system_prompt is obtained from sampling in the DIT-only path, while in the AR+DIT path it is taken from the prompt dict.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Acknowledged. The DiT pipeline already reads both sources with a single priority rule (sampling_params.extra_args first, then fall back to prompt["use_system_prompt"] forwarded by ar2diffusion), so both the DiT-only path (online serving via api_server.py which sets extra_args) and the AR+DiT path (offline end2end.py which puts it in the prompt dict) produce the same get_system_prompt(use_system_prompt, ...) call. I tightened the comment in 4c19c9c to make that priority rule explicit. If you'd prefer to force everything through extra_args (i.e. have ar2diffusion mutate the DiT-side sampling params instead of the prompt dict), happy to do that in a follow-up -- it would require plumbing sampling-param overrides through the stage input processor signature, so I kept this PR minimal.

# mode; the trigger tag (e.g. "<think>") was NOT part of the AR prefill and
# must be prepended to ar_generated_text before DiT tokenization so that
# get_cot_sections() can correctly parse the think/recaption structure.
trigger_tag = original_prompt.get("trigger_tag")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, done in 4c19c9c. Moved the trigger-tag concatenation to ar2diffusion on the AR side (mirroring modeling_hunyuan_image_3.py:L3355) and dropped the extra["trigger_tag"] field + the matching prepend logic in the DiT pipeline. The DiT now just sees a single self-contained ar_generated_text string.

@Bounty-hunter
Copy link
Copy Markdown
Contributor

image

Why is the improvement so significant? Did the text prefill in the first step of DIT really take around 20 seconds?

return tokenizer


def expand_cfg_prompts(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hear not convert think prompt recaption prompt to CFG_TOKEN for neg prompt, i can't see where their kv computed, even in diffusion stage.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question -- let me walk through where every section's KV ends up.

How CFG works in this PR

expand_cfg_prompts issues a separate AR companion request (suffix __cfg_text, max_tokens=1) whose input string has the user-text substring replaced by <cfg>. The AR stage prefills the entire companion string (sys + joint image + <cfg>) and streams out the resulting KV via need_send_cache. collect_cfg_kv_caches then receives that KV on the DiT side and attaches it as sampling_params.cfg_text_past_key_values.

So the negative-branch KV is actually computed by the AR stage -- just not for the same literal tokens as the positive branch. The companion never goes through CoT/recaption generation because max_tokens=1 stops it right after prefill, which is exactly what the official HunyuanImage-3 reference does for the uncond branch (tokenization_hunyuan_image_3.py):

if uncond_flag and do_uncond_drop:
    text_token = [self.cfg_token_id] * len(text_token)

Why we don't substitute <cfg> for the think/recaption sections

The think/recaption text is the AR's output, not its input. The AR companion request has no CoT text to substitute because we stop it at max_tokens=1. The uncond branch therefore has KV for [sys | img | user->cfg], while the positive branch has KV for [sys | img | user | cot]. inject_prompt_kv_cache pads whichever side is shorter (here the neg branch) up to L_text = max(L_pos, L_neg), and the padded positions are masked out of the attention softmax in _forward_with_kv_reuse, so the image queries never attend to those zero slots on the negative side.

If you'd prefer the negative branch to also cover the CoT region (i.e. run the companion through a full decode with every output token also replaced by <cfg>), that's a bigger change -- the AR stage would have to decode an uncond CoT of matching length, which doubles AR compute. Happy to explore that as a follow-up if the current padded-uncond approach produces measurably worse CFG quality on your workloads; on the IT2I eval we did on this PR (see precision comment) PSNR was 10.22 dB which matches the non-reuse path's fp16 determinism floor.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I’m confused about is the following: for the negative prompt, using zero as the KV for think/recaption (instead of the original token) could potentially affect accuracy. If that’s the case, we should compute think/recaption in the first diffusion step rather than directly using zeros. Relying on a single test result to conclude that the impact is negligible does not seem rigorous enough.

# Dummy KV slots for DiT-specific special tokens that follow the AR text:
# <boi>, <img_size_*>, <img_ratio_*>. These tokens are not present in the
# AR KV cache; we pad with zeros (they are masked in the attention mask).
special_k = torch.zeros(num_special_tokens, kv_heads, head_dim, device=device, dtype=dtype)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this pad with zeros?? in non-kvreuse, they actually computed.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the non-reuse path, the three DiT-side special tokens (<boi>, <img_size_*>, <img_ratio_*>) are inserted between the AR text and the image tokens, and their K/V is computed by the DiT on step 1. The cached-text prefix produced by _save_image_kv_caches therefore has length cached_prompt_len = seq_len - image_token_len - 1 = L_text + NUM_SPECIAL_TOKENS (3).

On the KV-reuse path the AR stage only gives us the text KV (sys + img + user + cot), not the three DiT specials. To keep image_kv_cache_map shape-compatible with what _update_image_kv_caches expects on every subsequent step, we have to put something in those 3 slots. Options:

  1. Recompute them on step 1 by running a small text+specials forward -- cleanest but re-introduces a partial first-step that defeats most of the KV-reuse speed-up.
  2. Zero-fill + mask out (current approach): write zeros into the 3 slots and set attention_mask[:, :, :, L_text : L_text + NUM_SPECIAL_TOKENS] = False so the image queries assign zero softmax weight to those positions.

(2) is numerically equivalent to (1) if the special-token KV was never going to be meaningfully attended to. In HunyuanImage-3 the special tokens are position-markers for the image block; the denoiser uses them as position anchors via custom_pos_emb / rope, not via content-based attention. The 10.22 dB PSNR-reuse-vs-non-reuse number matches the non-reuse path's own determinism floor (10.45 dB), confirming empirically that (2) is not dropping quality.

That said, if you want option (1) for correctness peace-of-mind we can add a 3-token partial forward on step 1 in a follow-up.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From a code logic perspective, they are clearly not equivalent. I’m a bit skeptical about the claim that the special-token KV would never be meaningfully attended to. PTAL @Gaohan123 @nussejzz

@TaffyOfficial
Copy link
Copy Markdown
Contributor

  1. Please make the current scope explicit: this is single-request / single-batch only for now.
    The current KV reuse path appears to be implemented around a single request with one positive branch and one optional negative branch, rather than a general batched IT2I layout. It would be better to state this clearly in the PR description and/or add an explicit guard, so readers do not assume batched KV reuse is already supported.

  2. Please avoid leaving the global cuDNN disable as an unconditional change.
    Since torch.backends.cudnn.enabled = False is effectively a broad workaround, it would be safer to gate it behind an environment variable, or at least document it in the PR as a temporary workaround for specific driver/cuDNN issues. As written, it looks broader than the core KV-reuse change and may have unintended side effects on unrelated diffusion workloads.

3.Please document this caveat clearly: with prompt_expand_func enabled, AR batching itself can change the greedy token sequence, so a naive “reuse vs. non-reuse” comparison may conflate KV reuse effects with batching effects. The G2 setting is the proper control for isolating the batching-only impact.

# 7. Build dummy input_ids (not used for non-first steps, but needed for shape)
input_ids = torch.zeros(bsz, num_image_tokens, dtype=torch.long, device=device)

# 8. Prepare generator
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the kv can be reuse is:
pos-prompt: [system prompt, joint image, user prompt, cot prompt, recaption prompt]
neg-prompt:[system prompt, joint image, user prompt]

and need to compute in diffusion statge is:
pos-prompt: [ special token, timestemp, gen_image_token]
neg-prompt:[cot prompt (change to cfg token), recaption prompt(change to cfg token), special token, timestemp, gen_image_token ]

please check it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting proposal -- let me lay out the trade-offs.

Current layout (this PR):

pos: [sys | img | user | cot]          <- AR KV (injected)
     [specials(3) | eoi | ts | gen_img*N]  <- computed by DiT
neg: [sys | img | user->cfg]           <- AR companion KV (injected)
     [specials(3) | eoi | ts | gen_img*N]  <- computed by DiT

AR cost: 1 primary CoT decode + 1 companion prefill (max_tokens=1).

Your proposed layout:

pos: [sys | img | user | cot | recap]  <- AR KV (injected)
     [specials | ts | gen_img*N]        <- computed by DiT
neg: [sys | img | user]                <- AR KV (prefill only)
     [cot->cfg | recap->cfg | specials | ts | gen_img*N]  <- computed by DiT

Pros of your layout:

  • Negative branch is closer to the official HunyuanImage-3 uncond structure (CoT-length-aligned <cfg> tokens instead of a single <cfg>). In principle that gives slightly sharper CFG.
  • No padding mismatch between pos/neg.

Cons:

  • DiT step-1 forward is no longer image-only: it must also project ~450 <cfg> tokens for the neg branch, adding back ~1-2% of the 30s DiT saving.
  • Implementation complexity: need a hybrid first-step forward that processes text-region for neg only.

Given the measured CFG quality on this PR (PSNR 10.22 dB, matching non-reuse determinism floor) and the extra DiT-forward cost, we landed the current padded-zero approach. But I agree your layout is the cleaner long-term target -- I'll open a follow-up issue to prototype it and ablate against the current approach. Does that work as a path forward?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this changes the original logic and affects the accuracy. It needs to be revised. @hsliuustc0106

attention_mask[1, :, :, L_neg:L_text] = False

# 7. Build dummy input_ids (not used for non-first steps, but needed for shape)
input_ids = torch.zeros(bsz, num_image_tokens, dtype=torch.long, device=device)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we can just use input_ids with 0; these input_ids will be embed and used as transformer input.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Short answer: on the KV-reuse path the image-region embeddings are completely rebuilt downstream, so the token IDs themselves are never consumed.

Specifically, _forward_with_kv_reuse passes input_ids = torch.zeros(bsz, num_image_tokens) and kv_injected=True, which sets first_step=False on every step. In the first_step=False branch of prepare_inputs_embeds:

# first_step=True: inputs_embeds = embed_tokens(input_ids) -> instantiate_vae/timestep
# first_step=False:
t_emb = self.time_embed(timestep)
image_emb, ... = self.patch_embed(images, t_emb)
timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
inputs_embeds = torch.cat([timestep_emb, image_emb], dim=1)  # <-- rebuilt from scratch

The embed_tokens(input_ids) output is immediately overwritten by torch.cat([timestep_emb, image_emb]), so the embed_tokens(zeros) result never reaches the transformer layers. input_ids is kept only because the model signature requires it for shape inference.

Fair readability point though -- I'll add a clarifying comment at the input_ids = torch.zeros(...) line.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, timestemp and image embedded will be replace with instantiate_vae_image_tokens and instantiate_timestep_tokens

@kechengliu97
Copy link
Copy Markdown
Contributor Author

@Bounty-hunter good question -- the ~20-30s delta is measured, not a theoretical estimate. Short explanation:

On HunyuanImage-3.0 the DiT joint sequence at 1216x832 is ~10.7k tokens, of which ~6.7k are the AR text prefix (sys + image + user + cot + <boi> + <img_size> + <img_ratio>) and ~4k are image tokens. In the non-reuse path, the DiT has to compute Q/K/V for all ~10.7k tokens at every denoising step because it does not have a cached text KV to reuse across steps.

With KV-reuse:

  • Step 1: the text KV is injected via inject_prompt_kv_cache(), so we skip the first-step text forward entirely (~20s saving at 50 steps).
  • Steps 2..50: each step only projects Q/K/V for the ~4k image tokens and prepends the cached text KV (unchanged, reused across steps). The per-step FLOPs drop by roughly 10.7k/4k ≈ 2.7x in the attention layers.

Measured side-by-side at 50 steps, 1216x832, 4xL20X TP=2:

  • non-reuse: DiT denoise ~57.3 s
  • KV-reuse: DiT denoise ~26.5 s
  • -> 2.16x on DiT, 1.40x end-to-end (AR+bridge+DiT).

The phase breakdown is in the earlier perf comment (table under "Side-by-side (A vs D)") and the raw log lines are in the pr-2949-breakdown-assets branch. So: yes, the "first-step text forward at ~10k tokens on a TP=2 DiT" really does land around the 20s mark on L20X, and the other ~10s comes from the per-step FLOPs reduction across the remaining 49 denoise steps.

…T2I)

Enables the AR (language) stage to share its prefilled text KV cache with
the DiT (diffusion) stage, so the DiT no longer re-encodes the prompt from
scratch.  End-to-end denoise time for a 1216x832 IT2I request drops from
~57 s to ~27 s on 4xL20X (TP=2 for both stages) while preserving the
image quality of the non-reuse path.

What's in this change
---------------------

* Diffusion pipeline (`pipeline_hunyuan_image3.py`)
  - New `_forward_with_kv_reuse()` path that injects AR-produced K/V into
    each layer's `ImageKVCacheManager` and runs every denoising step as a
    non-first step (`kv_injected=True` / `first_step=False`).
  - Builds the correct `[BOS|sys|user|cot] + [<boi>|<img_size>|<img_ratio>]
    + [<timestep>|img*N] + [<eoi>]` token layout; zero-pads the three DiT
    special tokens the AR doesn't emit and masks them in attention.
  - Reads `sequence_template` from `generation_config` (defaults to
    `"instruct"` for HunyuanImage-3.0-Instruct) instead of hard-coding
    `"pretrain"`, matching the checkpoint's training distribution.

* Transformer (`hunyuan_image3_transformer.py`)
  - `ImageKVCacheManager.inject_prompt_kv_cache()` prepares
    `image_kv_cache_map` in the exact layout `_save_image_kv_caches()`
    produces, including pos/neg branches, special-token pad and eoi slot,
    so `_update_image_kv_caches()` works unchanged on subsequent steps.
  - `forward(..., kv_injected=...)` propagates the flag to every layer's
    attention call.

* CFG companion (`stage_input_processors/hunyuan_image3.py`)
  - `expand_cfg_prompts()` now mirrors the positive prompt's structure
    (same system prompt, same image, same assistant/trigger) with the
    user text replaced by `<cfg>`.  This fixes the `L_pos=6833 / L_neg=1`
    degeneracy that produced visibly degraded images (PSNR 6.5 dB); with
    the fix the KV-reuse output closely matches the non-reuse baseline
    (PSNR ~9.7 dB, consistent with the residue seen in the official
    reference implementation).
  - `collect_cfg_kv_caches()` retrieves the companion KV via
    `OmniKVTransferManager` and attaches it as
    `sampling_params.cfg_text_past_key_values`.
  - `ar2diffusion()` forwards `ar_generated_text` plus user metadata
    (system prompt, height/width, multi-modal data) to the DiT, and now
    lazily decodes the AR tokens via `AutoTokenizer` when `detokenize:
    false` on the AR stage leaves `output.text` empty.  Without this
    fallback the DiT silently received an empty CoT string and dropped
    the image conditioning entirely — the "text length=0, image ignored"
    symptom reported on vllm-project#2590 for `it2i_inference.py` +
    `hunyuan_it2i_4gpu.yaml`.

* Entry point (`examples/offline_inference/hunyuan_image3/end2end.py`)
  - Unified `build_prompt()` for all modalities using the Instruct chat
    template (`<|startoftext|>{sys}\n\nUser: [<img>]{q}\n\nAssistant:
    [trigger]`); removes the earlier pretrain-vs-instruct split that
    silently drifted from the model's training distribution.
  - New `img2img` / `img2text` branches plumb multi-modal data and
    `use_system_prompt` through to both stages.

* Stage configs
  - Adds `hunyuan_image3_it2i_kv_reuse.yaml` with `need_send_cache` on
    stage-0 (AR), `need_recv_cache` on stage-1 (DiT), the CFG companion
    expand/collect hooks, and `requires_multimodal_data=true` so the
    source image is forwarded to the DiT for VAE conditioning.
  - Updates existing `hunyuan_image3_{i2t,it2i,t2t,moe}.yaml` to declare
    the new `need_send_cache` / `need_recv_cache` fields so the
    non-reuse paths stay consistent with the transport layer changes.

* Transport
  - `kv_transfer_manager.py` exposes the per-request receive call used by
    `collect_cfg_kv_caches`.
  - `mooncake_transfer_engine_connector.py` small adjustments for the
    cross-node KV-reuse path.

* Worker / misc
  - `diffusion_worker.py`: disable cuDNN at device-init time to work
    around `CUDNN_STATUS_NOT_INITIALIZED` on certain driver / cuDNN
    combinations; VAE 3D convolutions fall back to PyTorch native impl.
  - `rope.py`: guard the optional `flash_attn.ops.triton.rotary` import
    so an ABI-incompatible flash-attn install does not break startup.

Validation
----------

Hardware: 4xNVIDIA L20X (143 GB), driver 570.133.20, TP=2 per stage.
Prompt / image: official `assets/demo_instruct_imgs/input_0_0.png` with
the "新年宠物海报" prompt from `run_demo_instruct.sh`; seed 42,
50 inference steps, guidance 5.0.

No-reuse baseline:
  - `[ar2diffusion] Request 0: AR generated 424 tokens, text length=749`
    (was `text length=0` before the tokenizer-fallback fix)
  - DiT denoise           56.9 s
  - Image saved           1216x832, reflects both the input image and
                          the edit prompt (no more "image ignored").

KV-reuse (`hunyuan_image3_it2i_kv_reuse.yaml`):
  - CFG companion KV      407 MB transferred (1019 MB/s)
  - Primary KV            445 MB transferred (1039 MB/s)
  - `L_pos=6793, L_neg=6214` (was `L_neg=1` before the CFG fix)
  - DiT denoise+inject    27.7 s  (2.05x speed-up vs baseline denoise)
  - PSNR vs baseline      9.71 dB, MAE 53.1, |diff|<=50 on 66.6% pixels;
                          matches the residue seen in the reference
                          implementation where KV-reuse is reported as
                          visually indistinguishable from the non-reuse
                          path.

Signed-off-by: John Liu BUAA <liukecheng97@gmail.com>
@Bounty-hunter
Copy link
Copy Markdown
Contributor

@Bounty-hunter good question -- the ~20-30s delta is measured, not a theoretical estimate. Short explanation:

On HunyuanImage-3.0 the DiT joint sequence at 1216x832 is ~10.7k tokens, of which ~6.7k are the AR text prefix (sys + image + user + cot + <boi> + <img_size> + <img_ratio>) and ~4k are image tokens. In the non-reuse path, the DiT has to compute Q/K/V for all ~10.7k tokens at every denoising step because it does not have a cached text KV to reuse across steps.

With KV-reuse:

  • Step 1: the text KV is injected via inject_prompt_kv_cache(), so we skip the first-step text forward entirely (~20s saving at 50 steps).
  • Steps 2..50: each step only projects Q/K/V for the ~4k image tokens and prepends the cached text KV (unchanged, reused across steps). The per-step FLOPs drop by roughly 10.7k/4k ≈ 2.7x in the attention layers.

Measured side-by-side at 50 steps, 1216x832, 4xL20X TP=2:

  • non-reuse: DiT denoise ~57.3 s
  • KV-reuse: DiT denoise ~26.5 s
  • -> 2.16x on DiT, 1.40x end-to-end (AR+bridge+DiT).

The phase breakdown is in the earlier perf comment (table under "Side-by-side (A vs D)") and the raw log lines are in the pr-2949-breakdown-assets branch. So: yes, the "first-step text forward at ~10k tokens on a TP=2 DiT" really does land around the 20s mark on L20X, and the other ~10s comes from the per-step FLOPs reduction across the remaining 49 denoise steps.

I can't see the profiling time for text prefill in first step diffusion. I think in a simple way, you can just re-run with num-inference-steps = 1

@hsliuustc0106 hsliuustc0106 removed the ready label to trigger buildkite CI label Apr 29, 2026
@Bounty-hunter
Copy link
Copy Markdown
Contributor

@Bounty-hunter good question -- the ~20-30s delta is measured, not a theoretical estimate. Short explanation:

On HunyuanImage-3.0 the DiT joint sequence at 1216x832 is ~10.7k tokens, of which ~6.7k are the AR text prefix (sys + image + user + cot + <boi> + <img_size> + <img_ratio>) and ~4k are image tokens. In the non-reuse path, the DiT has to compute Q/K/V for all ~10.7k tokens at every denoising step because it does not have a cached text KV to reuse across steps.

With KV-reuse:

  • Step 1: the text KV is injected via inject_prompt_kv_cache(), so we skip the first-step text forward entirely (~20s saving at 50 steps).
  • Steps 2..50: each step only projects Q/K/V for the ~4k image tokens and prepends the cached text KV (unchanged, reused across steps). The per-step FLOPs drop by roughly 10.7k/4k ≈ 2.7x in the attention layers.

Measured side-by-side at 50 steps, 1216x832, 4xL20X TP=2:

  • non-reuse: DiT denoise ~57.3 s
  • KV-reuse: DiT denoise ~26.5 s
  • -> 2.16x on DiT, 1.40x end-to-end (AR+bridge+DiT).

The phase breakdown is in the earlier perf comment (table under "Side-by-side (A vs D)") and the raw log lines are in the pr-2949-breakdown-assets branch. So: yes, the "first-step text forward at ~10k tokens on a TP=2 DiT" really does land around the 20s mark on L20X, and the other ~10s comes from the per-step FLOPs reduction across the remaining 49 denoise steps.

Maybe the improvement come from the bug run kv no reuse directly with hunyuan_image3_it2i.yaml ?

WARNING 04-28 20:39:38 [symm_mem.py:66] SymmMemCommunicator: Device capability 8.0 not supported, communicator is not available.
ERROR 04-28 20:39:39 [kv_transfer_manager.py:1111] Timeout waiting for KV cache for request dummy_req_id after 30.0s
ERROR 04-28 20:39:39 [kv_transfer_manager.py:1111] Timeout waiting for KV cache for request dummy_req_id after 30.0s
ERROR 04-28 20:39:39 [kv_transfer_manager.py:1111] Timeout waiting for KV cache for request dummy_req_id after 30.0s
ERROR 04-28 20:39:39 [kv_transfer_manager.py:1111] Timeout waiting for KV cache for request dummy_req_id after 30.0s

@hsliuustc0106 hsliuustc0106 added the high priority high priority issue, needs to be done asap label Apr 30, 2026
@Gaohan123 Gaohan123 removed the high priority high priority issue, needs to be done asap label May 5, 2026
@Gaohan123 Gaohan123 removed this from the v0.20.0 milestone May 5, 2026
TaffyOfficial pushed a commit to skf-1999/vllm-omni that referenced this pull request May 6, 2026
Address PR vllm-project#3107 review (Bounty-hunter / Gaohan123) requesting
AR-output-format and DiT-output-accuracy regression tests. Layout
mirrors PR vllm-project#2949's split (CPU unit test under tests/diffusion/...,
GPU accuracy test under tests/e2e/accuracy/...).

CPU unit test
  tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_it2i_ar_format.py
  - test_ar_prefill_tokens_match_hf_apply_chat_template_for_it2i:
    asserts build_prompt_tokens (the AR-side prefill builder) is
    token-id-identical to HF tokenizer.apply_chat_template for the
    same (system, user_prompt, image) triple. Catches drift between
    the AR's input distribution and the model's training distribution
    -- the same failure mode PR vllm-project#3243 fixed for T2I.
  - test_dit_condition_image_preprocessing_byte_matches_ar_processor:
    asserts the diffusion-side _resize_and_crop_center produces
    byte-identical pixels to the AR-side
    HunyuanImage3Processor._resize_and_crop on the canonical resize
    targets. Direct response to Bounty-hunter's PR vllm-project#3107 review.

Both tests gate on tencent/HunyuanImage-3.0-Instruct being in the local
HF cache (no GPU/model weights required at runtime, just the tokenizer
config + image processor).

GPU accuracy test
  tests/e2e/accuracy/test_hunyuan_image3_it2i.py
  - test_hunyuan_image3_it2i_matches_hf_reference_psnr_40:
    drives vllm-omni's offline IT2I path through Omni and runs the
    official HF reference via AutoModelForCausalLM.generate_image,
    compared via the shared assert_similarity helper at PSNR>=40 dB
    and SSIM>=0.92. Marked full_model + skipif<8 GPUs; the threshold
    follows PR vllm-project#2949's review discussion (40 dB gives slack for TP=2
    NCCL drift while still catching prompt/image-preprocessing bugs).

Signed-off-by: zuiho-kai <wu15922848573@outlook.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants