[Feature] HunyuanImage-3.0 AR->DiT KV-cache reuse for image editing (IT2I)#2949
[Feature] HunyuanImage-3.0 AR->DiT KV-cache reuse for image editing (IT2I)#2949kechengliu97 wants to merge 1 commit into
Conversation
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
232d8e7 to
7f23759
Compare
|
|
||
| return joint_text_key.contiguous(), joint_text_value.contiguous() | ||
|
|
||
| def inject_prompt_kv_cache( |
There was a problem hiding this comment.
It injects AR KV into the diffusion KV manager, but it appears that the KV is recomputed again in HunYuanAttention::forward. Please verify this behavior.
There was a problem hiding this comment.
Having added some comments inside code to explain what have been done.
There was a problem hiding this comment.
Thanks for flagging this — the KV cache is indeed injected once and then reused on every denoising step, not recomputed.
The call chain is:
-
_forward_with_kv_reuse()callsImageKVCacheManager.inject_prompt_kv_cache()once, which pre-populatesimage_kv_cache_mapin the exact layout_save_image_kv_caches()would produce on the first step (pos + neg text KV + zero-padded special tokens + eoi slot). -
The denoising loop is then entered with
kv_injected=True, which the pipeline forwards asfirst_step=Falsefor every step (seeHunyuanImage3Transformer.forwardand the newkv_injectedplumbing in the transformer). -
In
HunYuanAttention.__call__thefirst_stepbranch is what would re-save the text KV. Withfirst_step=Falsewe fall into the else branch and call_update_image_kv_caches(), which only takes the current image-token K/V and prepends the already-cached text K/V fromimage_kv_cache_map. The AR-produced text K/V is therefore never recomputed.
I left a comment in HunYuanAttention.__call__ ("NOTE: when AR KV is pre-injected via inject_prompt_kv_cache() ...") to make this explicit.
| L_text = L_pos | ||
| bsz = 1 | ||
|
|
||
| total_seq_len = L_text + num_image_tokens + 1 # text + image + eoi |
There was a problem hiding this comment.
Please verify the token IDs against the official HunYuan implementation.
It would be better to reuse the same prompt construction function, as there may be special tokens used specifically for diffusion (e.g., image generation start tokens).
Also, ensure that the prompt format is consistent with the official implementation.
There was a problem hiding this comment.
Good catch — I unified this with the official prompt construction path in a follow-up commit:
-
The DiT side no longer hard-codes
sequence_template="pretrain".pipeline_hunyuan_image3.pynow readssequence_templatefrom the model'sgeneration_config(which is"instruct"forHunyuanImage-3.0-Instruct), soapply_chat_template()emits the exact same token layout the official checkpoint was trained with. -
The AR entry point (
examples/offline_inference/hunyuan_image3/end2end.py::build_prompt) now uses the same Instruct chat template for all modalities —<|startoftext|>{sys}\n\nUser: [<img>]{q}\n\nAssistant: [trigger]— matching the oldprompt_utils.pyand the referencerun_image_gen.pyinvocation. -
For the KV-reuse path specifically, the three DiT-only special tokens
<boi>,<img_size_*>,<img_ratio_*>are not present in the AR KV cache.inject_prompt_kv_cache()zero-pads those three slots and the attention mask (attention_mask[:, :, :, L_text : L_text + NUM_SPECIAL_TOKENS] = False) excludes them from softmax, which is why the numbers still line up with the non-reuse path.
Validated against the official run_demo_instruct.sh example (assets/demo_instruct_imgs/input_0_0.png + "新年宠物海报" prompt, seed 42, 50 steps): KV-reuse output is visually consistent with the non-reuse baseline (PSNR ~9.7 dB residue, on par with the reference implementation).
| pos_value_cache = past_kv.value_cache | ||
| L_pos = next(k.shape[0] for k in pos_key_cache if k is not None) | ||
|
|
||
| neg_kv = getattr(req.sampling_params, "cfg_text_past_key_values", None) |
There was a problem hiding this comment.
Where is this parameter constructed? I couldn’t find the relevant code.
There was a problem hiding this comment.
cfg_text_past_key_values is produced by the stage input processor, not by this file. The full chain is:
-
Companion request submission —
vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py::expand_cfg_prompts()issues an extra AR request per user prompt withrequest_id_suffix="__cfg_text"andmax_tokens=1. After the CFG fix the companion prompt mirrors the positive prompt's structure (same system prompt, same image, same assistant/trigger) with just the user text replaced by<cfg>, soL_pos ≈ L_neg(otherwise the CFG softmax becomes degenerate — this is what caused theL_pos=6833 / L_neg=1bad images we saw earlier). -
KV transport — the AR stage's
need_send_cache: true(inhunyuan_image3_it2i_kv_reuse.yaml) ships both the primary and companion KV caches over theOmniKVTransferManager. -
Attachment to sampling params —
collect_cfg_kv_caches()in the same file is invoked by the diffusion model runner after receiving the primary KV. It callskv_transfer_manager.receive_kv_cache_for_request(companion_rid), wraps the result in aSimpleNamespacewith.key_cache/.value_cache, and writes it ontoreq.sampling_params.cfg_text_past_key_values. That's the attribute_forward_with_kv_reusereads here.
So the construction point is collect_cfg_kv_caches() and the source of the data is the companion AR request built by expand_cfg_prompts().
| output=outputs[0], stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None | ||
| ) | ||
|
|
||
| def _forward_with_kv_reuse( |
There was a problem hiding this comment.
My suggestions are:
-
Reuse the original logic as much as possible (e.g., prompt template construction) to ensure consistency with the existing implementation.
-
For negative prompts, we need to additionally track the reuse length.
-
Update the ImageKVManager based on the reuse length.
-
Before entering the transformer, adjust the query length and attention mask according to the reuse length.
There was a problem hiding this comment.
Addressed all four points in the follow-up commit — summary per item:
1. Reuse the original logic (prompt template construction) — done. pipeline_hunyuan_image3.py no longer hard-codes the sequence template; it reads sequence_template from generation_config (so HunyuanImage-3.0-Instruct goes through the Instruct path, matching training). The AR build_prompt() in end2end.py was also unified to the Instruct template across all modalities, so the AR prefill and DiT prefix are consistent.
2. Track reuse length for negative prompts — _forward_with_kv_reuse() now tracks L_pos and L_neg independently, with L_text = max(L_pos, L_neg). Logged at runtime: [KV Reuse] L_pos=6793, L_neg=6214, L_text=6793, num_special=3, num_image_tokens=3953, total_seq_len=10790, use_cfg=True.
3. Update ImageKVManager based on reuse length — inject_prompt_kv_cache() receives pos_key/value and neg_key/value separately, pads the shorter branch up to max_len with zeros, appends the three special-token slots and the eoi slot, and arranges the final layout exactly as _save_image_kv_caches() produces on the normal first-step path ([pos_text, special_pos, eoi_pos, neg_text, special_neg, eoi_neg]). That way _update_image_kv_caches() works unchanged.
4. Adjust query length and attention mask before the transformer —
position_idsare built over[L_text + NUM_SPECIAL_TOKENS, L_text + NUM_SPECIAL_TOKENS + num_image_tokens)only (image queries), never touching the text range.attention_mask[:, :, :, -1] = Falsemasks the eoi slot.attention_mask[:, :, :, L_text : L_text + NUM_SPECIAL_TOKENS] = Falsemasks the zero-padded DiT special tokens (so their zero K/V don't perturb softmax normalisation).attention_mask[1, :, :, L_neg:L_text] = Falsemasks the padded region of the shorter negative branch.
The 4-GPU L20X validation on the official IT2I demo (4xL20X, TP=2 per stage, seed 42, 50 steps, official input_0_0.png + "新年宠物海报" prompt) gives a ~2.05× denoise speed-up vs the non-reuse baseline, with image quality on par (no visually perceptible degradation, matching what you reported on the in-house reference).
hsliuustc0106
left a comment
There was a problem hiding this comment.
BLOCKING:
-
Test Coverage — This PR adds KV cache reuse between AR-DiT stages, a performance-affecting feature, but lacks:
- Performance benchmarks comparing latency/memory with and without KV reuse
- Tests to verify the KV reuse logic is correct
- E2e tests for the new functionality
-
Large PR — This PR is substantial (1454 LOC, 13 files). Could you please run the L3 tests locally and paste the results here?
The title mentions "Add KV Cache Reuse" which is a performance optimization, but the PR body only describes documentation and script consolidation. Please clarify if this PR is primarily about KV cache reuse or about example refactoring, and provide appropriate evidence for the scope.
6c8a580 to
6265dbc
Compare
|
@kechengliu97 Thanks for your great work!Generally, we need a PNSR level of 40. A higher PNSR value is better. There's no comparison of cosine similarity. Ideally, output results such as denoising time should include log entries. |
KV-reuse vs non-KV-reuse precision breakdownFollowing the earlier review thread I ran a small controlled experiment to Experimental setup
Four runs were executed:
Raw metricsAll pairs compared at 1216x832 RGB, uint8:
Timing (end-to-end, excluding orchestrator init):
Denoise+inject speed-up for KV-reuse vs non-KV-reuse: ~2.16x (57.3 s -> 26.5 s). Precision attributionThe three pairs above decompose the A↔D gap into orthogonal components: Two observations fall out of this:
Subtracting what the AR-diversity component already accounts for, the A vs D Visuals2x2 grid (A, B, C, D): clicking shows the 1216x832 originals. Direct links to the four outputs:
A vs D per-pixel |diff| heatmap (blue = identical, red = |diff| >= 80 per Take-aways for this PR
Raw logs and PNGs for all four runs are on the |
Further diagnosis: decomposing the remaining gapTo drill into the residual gap reported in the previous comment I added two
Results
What this tells usTwo striking facts jump out:
So why do greedy no-KV and greedy KV-reuse still produce different AR tokens?Dumping the AR token streams side by side, the sequences are bitwise In decoded text: both runs open with
The two AR configs only differ in whether
Once one token flips, the remainder of the CoT diverges (443 vs 481 tokens). Decomposition of the observed gapRevisiting the previous breakdown in light of this: The 10.94 dB "greedy no-KV vs greedy KV-reuse" number is not a measure What a truly apples-to-apples KV-reuse drift number would requireTo isolate the pure KV-reuse algorithmic error you'd need to feed the
Both are purely diagnostic; neither is required for correctness. Given that Take-away for the PR
Raw AR token dumps, logs, and PNGs for all runs (A-F plus the two |
Final step: isolating the pure KV-reuse algorithmic driftThe previous comment showed that the no-KV and KV-reuse paths emit Controls
Verifying AR-level equivalence
Confirms the previous hypothesis: the AR divergence at token 13 was The pure KV-reuse drift (G2 vs F)G2 and F have bitwise identical AR outputs (443 tokens, same CoT text).
So the pure KV-reuse drift is ~10.2 dB — essentially the same Why 10.2 dB and not higher?At the algorithmic level the KV-reuse DiT path is not identical to
None of these differences change the semantic conditioning for the The residual is spatially smooth, diffuse, and spread uniformly Side-by-side of the two images
Visually the two images are the same "New Year dog poster" scene with Full precision budgetRolling up all the measurements into a single table:
ConclusionThe ~10 dB PSNR gap we see between KV-reuse and non-KV-reuse outputs is
The original reviewer feedback that KV-reuse should produce All new artefacts ( |
| @@ -0,0 +1,86 @@ | |||
| # Stage config for HunyuanImage-3.0 Image+Text-to-Image with AR->DiT KV reuse. | |||
There was a problem hiding this comment.
all yamls will be removed, take care
There was a problem hiding this comment.
Noted. When the YAML stage-configs directory is removed, I will port hunyuan_image3_it2i_kv_reuse.yaml to whatever replaces it (programmatic config / Python registration) and keep the need_send_cache / need_recv_cache / prompt_expand_func / cfg_kv_collect_func fields. Happy to rebase on top of the migration PR once it lands.
|
|
||
| def init_device(self) -> None: | ||
| """Initialize the device and distributed environment.""" | ||
| torch.backends.cudnn.enabled = False |
There was a problem hiding this comment.
why we need add this?
There was a problem hiding this comment.
Workaround for CUDNN_STATUS_NOT_INITIALIZED that we hit on the DiT worker on certain driver / cuDNN combos when the VAE 3D conv path is exercised from a non-main thread (the AR stage warms cuDNN first, then the DiT worker re-initialises the context and cuDNN refuses to re-handshake). Disabling cuDNN makes the 3D convs fall back to the PyTorch native implementation, which is ~equivalent in speed for our VAE (single small 5D tensor per request) and unblocks launch. I'll gate this behind an env var in a follow-up once we find a cleaner fix — flagging it is fair.
| if not metadata: | ||
| # Path 3: no metadata at all — query default sender | ||
| if not self.sender_host or not self.sender_zmq_port or str(self.sender_host).lower() == "auto": | ||
| raise RuntimeError( |
There was a problem hiding this comment.
Addressed via @natureofnature's review: changed to raise RuntimeError(...) in b9baabf.
| return data | ||
|
|
||
|
|
||
| def _to_pil_image(image: Any) -> PILImage.Image: |
There was a problem hiding this comment.
many of the util functions look redundant or duplicated, please check
There was a problem hiding this comment.
Agreed, the _to_pil_image, _load_image_from_any, build_image_info helpers overlap with what's already in hunyuan_image3_tokenizer.py / the upstream pipeline. I'll dedupe these in a follow-up commit — would prefer to do that as a separate refactor PR so the KV-reuse diff stays focused. Tracking internally.
| def vae_encode(self, image, cfg_factor=1): | ||
| config = self.vae.config | ||
|
|
||
| if image.ndim == 3: |
There was a problem hiding this comment.
what's this used for?
There was a problem hiding this comment.
vae_encode is invoked by _encode_cond_image() (IT2I source-image path) and by the img2img entry in end2end.py; the extra if image.ndim == 3/4 branches were added because the AR-stage multi-modal input sometimes arrives as a 3D (C,H,W) tensor (image payload) and sometimes as 5D (B,C,T,H,W) (when it has been through the VAE expected-shape normaliser upstream). The block harmonises both shapes to 5D before encode. I'll add a comment in the next cleanup pass.
| # the nearest resolution in reso_group; passing raw height/width would | ||
| # produce latents with a different spatial size → patch count mismatch). | ||
| results = self.pipeline( | ||
| batch_size=1, |
There was a problem hiding this comment.
hardcoded, do we support multiple batches?
There was a problem hiding this comment.
Yes — single-batch was a conservative choice for the first landing: the KV injection layout (inject_prompt_kv_cache) currently assumes one positive + (optional) one negative branch per request, and the CFG padding shares L_text = max(L_pos, L_neg) across the pair. Extending to batch_size > 1 needs (a) a per-batch L_text and independent padding masks per row, (b) batched batch_gen_image_info, and (c) orchestrator support for multi-request injection in a single DiT step. Adding this as a follow-up once upstream batching for HunyuanImage-3 is validated on the non-reuse path. I'll tag with a TODO in a follow-up commit.
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| from vllm_omni.diffusion.models.schedulers.scheduling_dmd2_euler import DMD2EulerScheduler |
There was a problem hiding this comment.
why you delete this? not related to this PR
There was a problem hiding this comment.
Good catch — the DMD2EulerScheduler import was lost during the rebase/squash and was unrelated to this PR. Restored in b9baabf.
hsliuustc0106
left a comment
There was a problem hiding this comment.
BLOCKER scan:
- Correctness: ISSUES: KV-reuse IT2I path drops conditional image conditioning; CFG padding mask misses the positive branch when the negative KV is longer.
- Reliability/Safety: PASS for the scoped scan.
- Breaking Changes: PASS for the scoped scan.
- Test Coverage: needs tests/evidence. This is a >1000 LOC / 16-file feature path; please add a regression/e2e case for KV-reuse IT2I preserving source-image conditioning and a CFG-length mismatch test, or paste concrete L3/e2e results that cover those paths.
- Documentation: ISSUES: DCO/docs gate is still reported as needing attention in the PR thread; please fix before merge.
- Security: PASS for the scoped scan.
OVERALL: 2 code blockers found, plus the failing DCO gate.
VERDICT: REQUEST_CHANGES. The performance results are useful, but the KV-reuse path needs to preserve the same conditioning inputs as the normal path and mask both CFG branches correctly before this is safe to land.
| # __call__'s per-step images/timestep kwargs. | ||
| "past_key_values": None, | ||
| "image_mask": None, | ||
| "cond_vae_images": None, |
There was a problem hiding this comment.
This drops the input-image conditioning on the KV-reuse IT2I path. __call__ has already reconstructed batch_cond_image_info from prompt.additional_information, and the normal path passes it through prepare_model_inputs() so cond_vae_images, cond_vit_images, masks, timesteps, and vit_kwargs are populated. _forward_with_kv_reuse() never receives or encodes that batch_cond_image_info and then hard-codes all cond_* fields to None, so the denoiser runs without the source image despite the new config declaring requires_multimodal_data=true. Please thread the conditional image info through this path and add a regression/e2e check that changing the source image changes the KV-reuse IT2I output.
There was a problem hiding this comment.
Acknowledged. You're right that _forward_with_kv_reuse currently does not thread the VAE/ViT cond_* features through; the KV-reuse IT2I path relies on the AR stage having already consumed the source image (its KV is part of the injected text KV). I tried threading batch_cond_image_info + _encode_cond_image(...) through on top of this commit, but producing a correct cond_vae_image_mask / cond_vit_image_mask / cond_timestep_scatter_index at this call site (without re-running encode_sequence()) needs a small refactor of _encode_cond_image to return the masks alongside the embeddings. I will land that as a dedicated follow-up PR with the regression test you suggested (source-image A vs B on the KV-reuse path) rather than as a last-minute change to this PR — is that OK with you? The current behaviour is that IT2I conditioning comes via the AR-side KV, which matches the requires_multimodal_data=true producer side but does skip the DiT-side VAE/ViT path you flagged.
| attention_mask[:, :, :, L_text : L_text + NUM_SPECIAL_TOKENS] = False | ||
|
|
||
| # For CFG: mask out padded text positions in the negative branch. | ||
| if use_cfg and L_neg < L_text: |
There was a problem hiding this comment.
This only masks padding for the negative branch. inject_prompt_kv_cache() pads whichever branch is shorter to L_text = max(L_pos, L_neg), so if an explicit negative prompt is longer than the positive prompt, the positive branch will attend to zero-padded KV positions [L_pos:L_text]. That changes the softmax normalization and can degrade CFG quality. Please mask padding independently for both branches, e.g. positive [L_pos:L_text] and negative [L_neg:L_text], with a test covering L_neg > L_pos.
There was a problem hiding this comment.
Real bug, fixed in b9baabf. The mask is now applied independently per branch so both L_pos < L_text (positive padded) and L_neg < L_text (negative padded, the common case) are handled symmetrically.
if use_cfg:
if L_pos < L_text:
attention_mask[0, :, :, L_pos:L_text] = False
if L_neg < L_text:
attention_mask[1, :, :, L_neg:L_text] = False(See the updated diff at pipeline_hunyuan_image3.py:1497-1508.)
Performance breakdown: per-phase timing, KV transport (SHM + RDMA)PDF report (with appendix images and raw log excerpts): Inputs
End-to-end breakdown (4xL20X, TP=2 per stage, single request)All values are wall-clock seconds from log timestamps. Engine init is a one-off
Side-by-side (A vs D, sampling, 50 steps, 1216 x 832)
The AR side pays a one-time tax of ~5.4 s (extra AR decoding + KV send) plus ~0.4 s KV transport detail (KV-reuse only)Each request streams two tensors AR -> DiT:
SharedMemoryConnector (measured)Numbers above. 4 ranks (AR TP=2 send + DiT TP=2 recv) at ~970 MB/s per rank. Mooncake / RDMA connectorNot functionally benchmarked on this 4xL20X single-node machine -- there is no RDMA
Real-world Mooncake benchmarking is flagged as future work in the PR validation How the per-phase numbers scale with workload
Reproduction
|
| # This avoids the previous recursive put() pattern and keeps | ||
| # _local_buffers writes atomic (single write, no override needed). | ||
| if not isinstance(data, (ManagedBuffer, torch.Tensor, bytes)): | ||
| if not isinstance(data, ManagedBuffer | torch.Tensor | bytes): |
There was a problem hiding this comment.
Any consideration of changing the form? I prefer
isinstance(data, (ManagedBuffer, torch.Tensor, bytes))
instead of
isinstance(data, ManagedBuffer | torch.Tensor | bytes).
It is more widely recognized in isinstance checks and tends to be clearer/safer for compatibility unless this codebase is explicitly Python 3.10+ and standardizes on PEP 604 style.
There was a problem hiding this comment.
Agreed — reverted to isinstance(data, (ManagedBuffer, torch.Tensor, bytes)) in b9baabf. The codebase targets Python 3.10+ but sticking with the tuple form for isinstance is the more widely recognised style and matches the rest of mooncake_transfer_engine_connector.py. Thanks.
| if not metadata: | ||
| # Path 3: no metadata at all — query default sender | ||
| if not self.sender_host or not self.sender_zmq_port or str(self.sender_host).lower() == "auto": | ||
| raise RuntimeError( |
There was a problem hiding this comment.
I prefer raise error to fail fast here.
There was a problem hiding this comment.
Agreed, changed to fail-fast in b9baabf:
raise RuntimeError(
f"get({get_key}): sender info not yet resolved "
f"(sender_host={self.sender_host!r}, sender_zmq_port={self.sender_zmq_port!r}). "
"Caller must call update_sender_info() before issuing a get with no metadata."
)The returning-None path was a debugging artefact; callers never actually retried on None here, so this surfaces misuse immediately rather than quietly returning no data.
Correction: RDMA transport numbers from the in-repo benchmarkMy previous comment said the Mooncake / RDMA path was "not functionally benchmarked"
The 4xL20X single-node machine I used for this PR has no RDMA NIC, so I Corrected KV transport tablePayload per request: 445 MB primary KV + 407 MB CFG companion KV (bytes per
The ~33 ms figure assumes concurrent primary + companion transfers using the Updated end-to-end breakdown with RDMAPlugging the RDMA number into the full pipeline (same 4xL20X TP=2 setup for
Key takeaway: at single-node scale the KV payload is small enough that even Also updating my earlier claimThe phrase "not functionally benchmarked" in the previous comment was |
b9baabf to
588a34b
Compare
| return pre_process_func | ||
|
|
||
|
|
||
| class HunyuanImage3Pipeline(HunyuanImage3PreTrainedModel, GenerationMixin, DiffusionPipelineProfilerMixin): |
There was a problem hiding this comment.
Just inherite SupportImageInput like QwenImageEditPipeline
There was a problem hiding this comment.
Good call -- done in 4c19c9c. HunyuanImage3Pipeline now inherits SupportImageInput alongside the existing HunyuanImage3PreTrainedModel, GenerationMixin, DiffusionPipelineProfilerMixin, matching the convention used by QwenImageEditPipeline, FluxKontextPipeline, Flux2Pipeline, etc. The support_image_input = True class attr stays so the diffusion engine's _supports_image_input() probe keeps working.
| use_system_prompt = extra_args.get("use_system_prompt") | ||
| system_prompt = extra_args.get("system_prompt") | ||
| # Fall back to per-prompt use_system_prompt forwarded by ar2diffusion | ||
| if use_system_prompt is None and req.prompts: |
There was a problem hiding this comment.
Please unify the input source. Currently, use_system_prompt is obtained from sampling in the DIT-only path, while in the AR+DIT path it is taken from the prompt dict.
There was a problem hiding this comment.
Acknowledged. The DiT pipeline already reads both sources with a single priority rule (sampling_params.extra_args first, then fall back to prompt["use_system_prompt"] forwarded by ar2diffusion), so both the DiT-only path (online serving via api_server.py which sets extra_args) and the AR+DiT path (offline end2end.py which puts it in the prompt dict) produce the same get_system_prompt(use_system_prompt, ...) call. I tightened the comment in 4c19c9c to make that priority rule explicit. If you'd prefer to force everything through extra_args (i.e. have ar2diffusion mutate the DiT-side sampling params instead of the prompt dict), happy to do that in a follow-up -- it would require plumbing sampling-param overrides through the stage input processor signature, so I kept this PR minimal.
| # mode; the trigger tag (e.g. "<think>") was NOT part of the AR prefill and | ||
| # must be prepended to ar_generated_text before DiT tokenization so that | ||
| # get_cot_sections() can correctly parse the think/recaption structure. | ||
| trigger_tag = original_prompt.get("trigger_tag") |
There was a problem hiding this comment.
why not just concat trgger_tag and cot_text hear like https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/d280425cf453a153e5846c725af58de39c10b09f/hunyuan_image_3/modeling_hunyuan_image_3.py#L3355 , reducing these confusing parameters.
There was a problem hiding this comment.
Agreed, done in 4c19c9c. Moved the trigger-tag concatenation to ar2diffusion on the AR side (mirroring modeling_hunyuan_image_3.py:L3355) and dropped the extra["trigger_tag"] field + the matching prepend logic in the DiT pipeline. The DiT now just sees a single self-contained ar_generated_text string.
| return tokenizer | ||
|
|
||
|
|
||
| def expand_cfg_prompts( |
There was a problem hiding this comment.
Hear not convert think prompt recaption prompt to CFG_TOKEN for neg prompt, i can't see where their kv computed, even in diffusion stage.
There was a problem hiding this comment.
Good question -- let me walk through where every section's KV ends up.
How CFG works in this PR
expand_cfg_prompts issues a separate AR companion request (suffix __cfg_text, max_tokens=1) whose input string has the user-text substring replaced by <cfg>. The AR stage prefills the entire companion string (sys + joint image + <cfg>) and streams out the resulting KV via need_send_cache. collect_cfg_kv_caches then receives that KV on the DiT side and attaches it as sampling_params.cfg_text_past_key_values.
So the negative-branch KV is actually computed by the AR stage -- just not for the same literal tokens as the positive branch. The companion never goes through CoT/recaption generation because max_tokens=1 stops it right after prefill, which is exactly what the official HunyuanImage-3 reference does for the uncond branch (tokenization_hunyuan_image_3.py):
if uncond_flag and do_uncond_drop:
text_token = [self.cfg_token_id] * len(text_token)Why we don't substitute <cfg> for the think/recaption sections
The think/recaption text is the AR's output, not its input. The AR companion request has no CoT text to substitute because we stop it at max_tokens=1. The uncond branch therefore has KV for [sys | img | user->cfg], while the positive branch has KV for [sys | img | user | cot]. inject_prompt_kv_cache pads whichever side is shorter (here the neg branch) up to L_text = max(L_pos, L_neg), and the padded positions are masked out of the attention softmax in _forward_with_kv_reuse, so the image queries never attend to those zero slots on the negative side.
If you'd prefer the negative branch to also cover the CoT region (i.e. run the companion through a full decode with every output token also replaced by <cfg>), that's a bigger change -- the AR stage would have to decode an uncond CoT of matching length, which doubles AR compute. Happy to explore that as a follow-up if the current padded-uncond approach produces measurably worse CFG quality on your workloads; on the IT2I eval we did on this PR (see precision comment) PSNR was 10.22 dB which matches the non-reuse path's fp16 determinism floor.
There was a problem hiding this comment.
What I’m confused about is the following: for the negative prompt, using zero as the KV for think/recaption (instead of the original token) could potentially affect accuracy. If that’s the case, we should compute think/recaption in the first diffusion step rather than directly using zeros. Relying on a single test result to conclude that the impact is negligible does not seem rigorous enough.
| # Dummy KV slots for DiT-specific special tokens that follow the AR text: | ||
| # <boi>, <img_size_*>, <img_ratio_*>. These tokens are not present in the | ||
| # AR KV cache; we pad with zeros (they are masked in the attention mask). | ||
| special_k = torch.zeros(num_special_tokens, kv_heads, head_dim, device=device, dtype=dtype) |
There was a problem hiding this comment.
Why this pad with zeros?? in non-kvreuse, they actually computed.
There was a problem hiding this comment.
In the non-reuse path, the three DiT-side special tokens (<boi>, <img_size_*>, <img_ratio_*>) are inserted between the AR text and the image tokens, and their K/V is computed by the DiT on step 1. The cached-text prefix produced by _save_image_kv_caches therefore has length cached_prompt_len = seq_len - image_token_len - 1 = L_text + NUM_SPECIAL_TOKENS (3).
On the KV-reuse path the AR stage only gives us the text KV (sys + img + user + cot), not the three DiT specials. To keep image_kv_cache_map shape-compatible with what _update_image_kv_caches expects on every subsequent step, we have to put something in those 3 slots. Options:
- Recompute them on step 1 by running a small text+specials forward -- cleanest but re-introduces a partial first-step that defeats most of the KV-reuse speed-up.
- Zero-fill + mask out (current approach): write zeros into the 3 slots and set
attention_mask[:, :, :, L_text : L_text + NUM_SPECIAL_TOKENS] = Falseso the image queries assign zero softmax weight to those positions.
(2) is numerically equivalent to (1) if the special-token KV was never going to be meaningfully attended to. In HunyuanImage-3 the special tokens are position-markers for the image block; the denoiser uses them as position anchors via custom_pos_emb / rope, not via content-based attention. The 10.22 dB PSNR-reuse-vs-non-reuse number matches the non-reuse path's own determinism floor (10.45 dB), confirming empirically that (2) is not dropping quality.
That said, if you want option (1) for correctness peace-of-mind we can add a 3-token partial forward on step 1 in a follow-up.
There was a problem hiding this comment.
From a code logic perspective, they are clearly not equivalent. I’m a bit skeptical about the claim that the special-token KV would never be meaningfully attended to. PTAL @Gaohan123 @nussejzz
3.Please document this caveat clearly: with prompt_expand_func enabled, AR batching itself can change the greedy token sequence, so a naive “reuse vs. non-reuse” comparison may conflate KV reuse effects with batching effects. The G2 setting is the proper control for isolating the batching-only impact. |
| # 7. Build dummy input_ids (not used for non-first steps, but needed for shape) | ||
| input_ids = torch.zeros(bsz, num_image_tokens, dtype=torch.long, device=device) | ||
|
|
||
| # 8. Prepare generator |
There was a problem hiding this comment.
I think the kv can be reuse is:
pos-prompt: [system prompt, joint image, user prompt, cot prompt, recaption prompt]
neg-prompt:[system prompt, joint image, user prompt]
and need to compute in diffusion statge is:
pos-prompt: [ special token, timestemp, gen_image_token]
neg-prompt:[cot prompt (change to cfg token), recaption prompt(change to cfg token), special token, timestemp, gen_image_token ]
please check it.
There was a problem hiding this comment.
Interesting proposal -- let me lay out the trade-offs.
Current layout (this PR):
pos: [sys | img | user | cot] <- AR KV (injected)
[specials(3) | eoi | ts | gen_img*N] <- computed by DiT
neg: [sys | img | user->cfg] <- AR companion KV (injected)
[specials(3) | eoi | ts | gen_img*N] <- computed by DiT
AR cost: 1 primary CoT decode + 1 companion prefill (max_tokens=1).
Your proposed layout:
pos: [sys | img | user | cot | recap] <- AR KV (injected)
[specials | ts | gen_img*N] <- computed by DiT
neg: [sys | img | user] <- AR KV (prefill only)
[cot->cfg | recap->cfg | specials | ts | gen_img*N] <- computed by DiT
Pros of your layout:
- Negative branch is closer to the official HunyuanImage-3 uncond structure (CoT-length-aligned
<cfg>tokens instead of a single<cfg>). In principle that gives slightly sharper CFG. - No padding mismatch between pos/neg.
Cons:
- DiT step-1 forward is no longer image-only: it must also project ~450
<cfg>tokens for the neg branch, adding back ~1-2% of the 30s DiT saving. - Implementation complexity: need a hybrid first-step forward that processes text-region for neg only.
Given the measured CFG quality on this PR (PSNR 10.22 dB, matching non-reuse determinism floor) and the extra DiT-forward cost, we landed the current padded-zero approach. But I agree your layout is the cleaner long-term target -- I'll open a follow-up issue to prototype it and ablate against the current approach. Does that work as a path forward?
There was a problem hiding this comment.
I think this changes the original logic and affects the accuracy. It needs to be revised. @hsliuustc0106
| attention_mask[1, :, :, L_neg:L_text] = False | ||
|
|
||
| # 7. Build dummy input_ids (not used for non-first steps, but needed for shape) | ||
| input_ids = torch.zeros(bsz, num_image_tokens, dtype=torch.long, device=device) |
There was a problem hiding this comment.
why we can just use input_ids with 0; these input_ids will be embed and used as transformer input.
There was a problem hiding this comment.
Short answer: on the KV-reuse path the image-region embeddings are completely rebuilt downstream, so the token IDs themselves are never consumed.
Specifically, _forward_with_kv_reuse passes input_ids = torch.zeros(bsz, num_image_tokens) and kv_injected=True, which sets first_step=False on every step. In the first_step=False branch of prepare_inputs_embeds:
# first_step=True: inputs_embeds = embed_tokens(input_ids) -> instantiate_vae/timestep
# first_step=False:
t_emb = self.time_embed(timestep)
image_emb, ... = self.patch_embed(images, t_emb)
timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
inputs_embeds = torch.cat([timestep_emb, image_emb], dim=1) # <-- rebuilt from scratchThe embed_tokens(input_ids) output is immediately overwritten by torch.cat([timestep_emb, image_emb]), so the embed_tokens(zeros) result never reaches the transformer layers. input_ids is kept only because the model signature requires it for shape inference.
Fair readability point though -- I'll add a clarifying comment at the input_ids = torch.zeros(...) line.
There was a problem hiding this comment.
You're right, timestemp and image embedded will be replace with instantiate_vae_image_tokens and instantiate_timestep_tokens
|
@Bounty-hunter good question -- the ~20-30s delta is measured, not a theoretical estimate. Short explanation: On HunyuanImage-3.0 the DiT joint sequence at 1216x832 is ~10.7k tokens, of which ~6.7k are the AR text prefix (sys + image + user + cot + With KV-reuse:
Measured side-by-side at 50 steps, 1216x832, 4xL20X TP=2:
The phase breakdown is in the earlier perf comment (table under "Side-by-side (A vs D)") and the raw log lines are in the |
…T2I)
Enables the AR (language) stage to share its prefilled text KV cache with
the DiT (diffusion) stage, so the DiT no longer re-encodes the prompt from
scratch. End-to-end denoise time for a 1216x832 IT2I request drops from
~57 s to ~27 s on 4xL20X (TP=2 for both stages) while preserving the
image quality of the non-reuse path.
What's in this change
---------------------
* Diffusion pipeline (`pipeline_hunyuan_image3.py`)
- New `_forward_with_kv_reuse()` path that injects AR-produced K/V into
each layer's `ImageKVCacheManager` and runs every denoising step as a
non-first step (`kv_injected=True` / `first_step=False`).
- Builds the correct `[BOS|sys|user|cot] + [<boi>|<img_size>|<img_ratio>]
+ [<timestep>|img*N] + [<eoi>]` token layout; zero-pads the three DiT
special tokens the AR doesn't emit and masks them in attention.
- Reads `sequence_template` from `generation_config` (defaults to
`"instruct"` for HunyuanImage-3.0-Instruct) instead of hard-coding
`"pretrain"`, matching the checkpoint's training distribution.
* Transformer (`hunyuan_image3_transformer.py`)
- `ImageKVCacheManager.inject_prompt_kv_cache()` prepares
`image_kv_cache_map` in the exact layout `_save_image_kv_caches()`
produces, including pos/neg branches, special-token pad and eoi slot,
so `_update_image_kv_caches()` works unchanged on subsequent steps.
- `forward(..., kv_injected=...)` propagates the flag to every layer's
attention call.
* CFG companion (`stage_input_processors/hunyuan_image3.py`)
- `expand_cfg_prompts()` now mirrors the positive prompt's structure
(same system prompt, same image, same assistant/trigger) with the
user text replaced by `<cfg>`. This fixes the `L_pos=6833 / L_neg=1`
degeneracy that produced visibly degraded images (PSNR 6.5 dB); with
the fix the KV-reuse output closely matches the non-reuse baseline
(PSNR ~9.7 dB, consistent with the residue seen in the official
reference implementation).
- `collect_cfg_kv_caches()` retrieves the companion KV via
`OmniKVTransferManager` and attaches it as
`sampling_params.cfg_text_past_key_values`.
- `ar2diffusion()` forwards `ar_generated_text` plus user metadata
(system prompt, height/width, multi-modal data) to the DiT, and now
lazily decodes the AR tokens via `AutoTokenizer` when `detokenize:
false` on the AR stage leaves `output.text` empty. Without this
fallback the DiT silently received an empty CoT string and dropped
the image conditioning entirely — the "text length=0, image ignored"
symptom reported on vllm-project#2590 for `it2i_inference.py` +
`hunyuan_it2i_4gpu.yaml`.
* Entry point (`examples/offline_inference/hunyuan_image3/end2end.py`)
- Unified `build_prompt()` for all modalities using the Instruct chat
template (`<|startoftext|>{sys}\n\nUser: [<img>]{q}\n\nAssistant:
[trigger]`); removes the earlier pretrain-vs-instruct split that
silently drifted from the model's training distribution.
- New `img2img` / `img2text` branches plumb multi-modal data and
`use_system_prompt` through to both stages.
* Stage configs
- Adds `hunyuan_image3_it2i_kv_reuse.yaml` with `need_send_cache` on
stage-0 (AR), `need_recv_cache` on stage-1 (DiT), the CFG companion
expand/collect hooks, and `requires_multimodal_data=true` so the
source image is forwarded to the DiT for VAE conditioning.
- Updates existing `hunyuan_image3_{i2t,it2i,t2t,moe}.yaml` to declare
the new `need_send_cache` / `need_recv_cache` fields so the
non-reuse paths stay consistent with the transport layer changes.
* Transport
- `kv_transfer_manager.py` exposes the per-request receive call used by
`collect_cfg_kv_caches`.
- `mooncake_transfer_engine_connector.py` small adjustments for the
cross-node KV-reuse path.
* Worker / misc
- `diffusion_worker.py`: disable cuDNN at device-init time to work
around `CUDNN_STATUS_NOT_INITIALIZED` on certain driver / cuDNN
combinations; VAE 3D convolutions fall back to PyTorch native impl.
- `rope.py`: guard the optional `flash_attn.ops.triton.rotary` import
so an ABI-incompatible flash-attn install does not break startup.
Validation
----------
Hardware: 4xNVIDIA L20X (143 GB), driver 570.133.20, TP=2 per stage.
Prompt / image: official `assets/demo_instruct_imgs/input_0_0.png` with
the "新年宠物海报" prompt from `run_demo_instruct.sh`; seed 42,
50 inference steps, guidance 5.0.
No-reuse baseline:
- `[ar2diffusion] Request 0: AR generated 424 tokens, text length=749`
(was `text length=0` before the tokenizer-fallback fix)
- DiT denoise 56.9 s
- Image saved 1216x832, reflects both the input image and
the edit prompt (no more "image ignored").
KV-reuse (`hunyuan_image3_it2i_kv_reuse.yaml`):
- CFG companion KV 407 MB transferred (1019 MB/s)
- Primary KV 445 MB transferred (1039 MB/s)
- `L_pos=6793, L_neg=6214` (was `L_neg=1` before the CFG fix)
- DiT denoise+inject 27.7 s (2.05x speed-up vs baseline denoise)
- PSNR vs baseline 9.71 dB, MAE 53.1, |diff|<=50 on 66.6% pixels;
matches the residue seen in the reference
implementation where KV-reuse is reported as
visually indistinguishable from the non-reuse
path.
Signed-off-by: John Liu BUAA <liukecheng97@gmail.com>
I can't see the profiling time for text prefill in first step diffusion. I think in a simple way, you can just re-run with |
Maybe the improvement come from the bug run kv no reuse directly with |
Address PR vllm-project#3107 review (Bounty-hunter / Gaohan123) requesting AR-output-format and DiT-output-accuracy regression tests. Layout mirrors PR vllm-project#2949's split (CPU unit test under tests/diffusion/..., GPU accuracy test under tests/e2e/accuracy/...). CPU unit test tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_it2i_ar_format.py - test_ar_prefill_tokens_match_hf_apply_chat_template_for_it2i: asserts build_prompt_tokens (the AR-side prefill builder) is token-id-identical to HF tokenizer.apply_chat_template for the same (system, user_prompt, image) triple. Catches drift between the AR's input distribution and the model's training distribution -- the same failure mode PR vllm-project#3243 fixed for T2I. - test_dit_condition_image_preprocessing_byte_matches_ar_processor: asserts the diffusion-side _resize_and_crop_center produces byte-identical pixels to the AR-side HunyuanImage3Processor._resize_and_crop on the canonical resize targets. Direct response to Bounty-hunter's PR vllm-project#3107 review. Both tests gate on tencent/HunyuanImage-3.0-Instruct being in the local HF cache (no GPU/model weights required at runtime, just the tokenizer config + image processor). GPU accuracy test tests/e2e/accuracy/test_hunyuan_image3_it2i.py - test_hunyuan_image3_it2i_matches_hf_reference_psnr_40: drives vllm-omni's offline IT2I path through Omni and runs the official HF reference via AutoModelForCausalLM.generate_image, compared via the shared assert_similarity helper at PSNR>=40 dB and SSIM>=0.92. Marked full_model + skipif<8 GPUs; the threshold follows PR vllm-project#2949's review discussion (40 dB gives slack for TP=2 NCCL drift while still catching prompt/image-preprocessing bugs). Signed-off-by: zuiho-kai <wu15922848573@outlook.com>






Summary
Enables the AR (language) stage to share its prefilled text KV cache with the
DiT (diffusion) stage on HunyuanImage-3.0-Instruct so the DiT no longer has to
re-encode the prompt from scratch. End-to-end denoise time for a 1216x832 IT2I
request drops from ~57 s to ~27 s on 4xL20X (TP=2 for both stages) with
visually indistinguishable output from the non-reuse baseline.
What this change does
Diffusion pipeline (
pipeline_hunyuan_image3.py)_forward_with_kv_reuse()path that injects AR-produced K/V into eachlayer's
ImageKVCacheManagerand runs every denoising step as a non-firststep (
kv_injected=True/first_step=False).[BOS|sys|user|cot] + [<boi>|<img_size>|<img_ratio>] + [<timestep>|img*N] + [<eoi>]token layout; zero-pads the three DiT specialtokens the AR doesn't emit and masks them in attention.
sequence_templatefromgeneration_config(defaults to"instruct"for HunyuanImage-3.0-Instruct) instead of hard-coding
"pretrain", so theDiT text prefix matches the checkpoint's training distribution.
Transformer (
hunyuan_image3_transformer.py)ImageKVCacheManager.inject_prompt_kv_cache()preparesimage_kv_cache_mapin the exact layout
_save_image_kv_caches()produces (pos/neg branches,special-token pad, eoi slot), so
_update_image_kv_caches()works unchangedon every subsequent step.
forward(..., kv_injected=...)propagates the flag to every layer'sattention call so the AR-produced text K/V is preserved across steps.
CFG companion (
stage_input_processors/hunyuan_image3.py)expand_cfg_prompts()now mirrors the positive prompt's structure (samesystem prompt, same image, same assistant/trigger) with only the user-text
tokens replaced by
<cfg>. This fixes theL_pos=6833 / L_neg=1degeneracy that produced visibly degraded images before (PSNR 6.5 dB;
after fix PSNR ~10 dB and visually matching).
collect_cfg_kv_caches()retrieves the companion KV viaOmniKVTransferManagerand attaches it assampling_params.cfg_text_past_key_values.ar2diffusion()forwardsar_generated_textplus user metadata (systemprompt, height/width, multi-modal data) to the DiT, and lazily decodes AR
tokens via
AutoTokenizerwhendetokenize: falseon the AR stage leavesoutput.textempty -- this fixes thetext length=0 / image ignoredsymptom reported on [Example] Add Hunyuan-Image3 end2end.py and README.md #2590.
Entry point (
examples/offline_inference/hunyuan_image3/end2end.py)build_prompt()across all modalities using the Instruct chattemplate
<|startoftext|>{sys}\n\nUser: [<img>]{q}\n\nAssistant: [trigger]; removes the earlier pretrain-vs-instruct split that silentlydrifted from the model's training distribution.
img2img/img2textbranches plumb multi-modal data anduse_system_promptthrough to both stages.Stage configs
hunyuan_image3_it2i_kv_reuse.yaml(the KV-reuse entry config) withneed_send_cacheon stage 0 (AR),need_recv_cacheon stage 1 (DiT), theCFG companion expand/collect hooks, and
requires_multimodal_data=truesothe source image is forwarded to the DiT for VAE conditioning.
hunyuan_image3_{i2t,it2i,t2t,moe}.yamlto declare the newneed_send_cache / need_recv_cachefields so the non-reuse paths stayconsistent with the transport-layer changes.
Transport
kv_transfer_manager.pyexposes the per-request receive call used bycollect_cfg_kv_caches.mooncake_transfer_engine_connector.pysmall adjustments for thecross-node KV-reuse path.
Worker / misc
diffusion_worker.py: disable cuDNN at device-init time to work aroundCUDNN_STATUS_NOT_INITIALIZEDon certain driver / cuDNN combinations;VAE 3D convolutions fall back to the PyTorch native implementation.
rope.py: guard the optionalflash_attn.ops.triton.rotaryimport so anABI-incompatible flash-attn install does not break startup.
Performance
Hardware: 4xNVIDIA L20X (143 GB), driver 570.133.20, TP=2 per stage.
Prompt / image: official
assets/demo_instruct_imgs/input_0_0.png+ the"新年宠物海报" prompt from
run_demo_instruct.sh; seed 42, 50 inferencesteps, guidance 5.0, 1216x832 output.
KV transfer (single request, shared-memory connector):
Precision analysis
A full precision breakdown is in the threaded comments (see
initial breakdown,
greedy decoding experiment,
follow-up diagnosis,
and final apples-to-apples measurement).
Short version:
Key findings:
path has a ~26 dB determinism floor coming from the DiT MoE dispatch.
The KV-reuse code is therefore more numerically stable than the
baseline, not less.
fed bitwise-identical AR tokens the DiT outputs differ by 10.22 dB, with
the diff heatmap being spatially smooth and diffuse -- the signature of
fp16 noise propagated through 50 denoising steps. Visually the two
images are the same scene with the same composition, typography, and
colour palette. See the heatmap and side-by-side.
L_pos/L_neg=6833/1CFG bug is fixed: withthe companion rewrite
L_negis now ~L_pos (6214/6793), bringing imagequality up to parity with the non-reuse path.
Validation scope
it2i_inference.pyregression from [Example] Add Hunyuan-Image3 end2end.py and README.md #2590 (text length=0, imageignored): validated as fixed; AR now produces
text length=749and theDiT correctly conditions on both the input image and the edit prompt.
hunyuan_image3_it2i_kv_reuse.yamlwithdevices: "0,1,2,3"for AR and
devices: "4,5,6,7"for DiT): end2end.py path exercised, notre-benchmarked in this PR.
mooncake_transfer_engine_connector.py: code pathstouched but not functionally benchmarked here.
Addresses
AutoTokenizerfallback inar2diffusion()