[Feature] [HunyuanImage3] Use FlashAttention for image denoising #1975
nussejzz wants to merge 4 commits into
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 5353b9e686
Gaohan123 left a comment
Please supplement performance test results
OK~
Force-pushed: 1bd0d0f to d254f88
Why does it only improve performance a little?
From my point of view, it is because the 1.06x figure is end-to-end (including the MoE FFN with 64 experts, VAE decode, etc.), not attention-only. Attention is a small fraction of per-step compute.
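The argument above is Amdahl's law. The sketch below is only an illustration: the 10% attention share and the 2x kernel speedup are hypothetical inputs, not measurements from this PR. The one number taken from the thread is the reported 1.06x, which by itself implies that attention accounted for at least roughly 5.7% of the original step time.

```python
def overall_speedup(fraction: float, component_speedup: float) -> float:
    """Amdahl's law: end-to-end speedup when only `fraction` of the runtime gets faster."""
    return 1.0 / ((1.0 - fraction) + fraction / component_speedup)


def min_fraction_for(overall: float) -> float:
    """Smallest runtime share the accelerated part must have had to reach `overall`.

    This is the limit of Amdahl's law as the component speedup goes to infinity.
    """
    return 1.0 - 1.0 / overall


# Hypothetical inputs: attention is 10% of a step and its kernel becomes 2x faster.
print(f"{overall_speedup(0.10, 2.0):.3f}x end-to-end")

# Lower bound on the attention share implied by the reported 1.06x.
print(f"attention share >= {min_fraction_for(1.06):.1%}")
```

A profiler trace would replace the hypothetical share with a measured one.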
So, can you profile it?
Sure, I'll do further analysis and testing once the foreign exchange arrives tomorrow.😊❤️
lishunyang12 left a comment
Implementation looks correct. Two things:
- Use the `_unwrap_flash_output` helper instead of the inline `isinstance(attn_output, tuple)` check; the codebase already has this utility.
- Test results (latency + numerical accuracy vs the SDPA path) should be provided before merge.
Force-pushed: 9f2caf8 to 01d4b18
Force-pushed: e607efe to dcebc95
princepride left a comment
LGTM, but I want to know why you changed the yaml file.
princepride left a comment
We need to consider the timestamp token's attention 😂
No problem
Force-pushed: 8f767bb to ba1de13
Force-pushed: 0f315a8 to 9248468
Can you profile it using torch profiler? Useful links: https://docs.vllm.ai/projects/vllm-omni/en/latest/contributing/profiling/
Pull request overview
This PR updates HunyuanImage-3 denoising attention to use FlashAttention where possible, reducing memory usage by avoiding GQA KV expansion (repeat_kv) and removing the unused EOI token from KV caches.
Changes:
- Add FlashAttention-based multi-path dispatch for step 1 vs steps 2–50, and non-SP vs SP execution.
- Remove EOI from image KV caching and adjust masks/kv lengths accordingly.
- Stop forcing `DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA` in the pipeline constructor.
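To make the memory claim about avoiding GQA KV expansion concrete, here is a back-of-the-envelope sketch. All shapes are hypothetical placeholders, not HunyuanImage-3's real configuration. The only figure borrowed from the PR is the repeat factor of 4, chosen to match the "4x memory saving" stated in the PR summary.

```python
def kv_bytes(batch: int, seq_len: int, kv_heads: int, head_dim: int, bytes_per_elem: int = 2) -> int:
    """Bytes held by one layer's K and V tensors (the leading 2 covers both)."""
    return 2 * batch * seq_len * kv_heads * head_dim * bytes_per_elem


def expanded_kv_bytes(batch: int, seq_len: int, kv_heads: int, head_dim: int,
                      n_rep: int, bytes_per_elem: int = 2) -> int:
    """Same tensors after a repeat_kv-style expansion to kv_heads * n_rep heads."""
    return kv_bytes(batch, seq_len, kv_heads * n_rep, head_dim, bytes_per_elem)


# Hypothetical shapes: 1 sample, 4096 tokens, 8 KV heads, head_dim 128, 16-bit elements.
native = kv_bytes(1, 4096, 8, 128)
expanded = expanded_kv_bytes(1, 4096, 8, 128, n_rep=4)

print(f"native KV:   {native / 2**20:.0f} MiB per layer")
print(f"expanded KV: {expanded / 2**20:.0f} MiB per layer")
print(f"ratio:       {expanded // native}x")
```

A kernel that accepts fewer KV heads than query heads can read the native tensors directly, so the expanded copy never needs to be materialized.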
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py | Removes global env forcing of SDPA backend so the model can select FA paths automatically. |
| vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py | Implements FlashAttention dispatch paths, packed varlen FA, SP timestamp fix, and removes EOI KV caching. |
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=q_len - 1,
    max_seqlen_k=kv_len,
    causal=False,
The varlen FlashAttention call doesn’t pass softmax_scale=self.scaling, while the other FlashAttention branches in this function do. If self.scaling differs from FlashAttention’s default scaling, this will silently change model numerics vs the other FA paths and vs the original masked implementation. Pass softmax_scale=self.scaling (if supported by your _flash_attn_varlen_func wrapper), or otherwise ensure/document that self.scaling is exactly the FA default for this model.
Suggested change:
    -causal=False,
    +causal=False,
    +softmax_scale=self.scaling,
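The concern is easy to see in a small pure-Python sketch. FlashAttention's public functions fall back to 1/sqrt(head_dim) when `softmax_scale` is left as `None`; whether the `_flash_attn_varlen_func` wrapper in this codebase does the same is an assumption here. The vectors and the 0.25 scale below are hypothetical, chosen only to show that a mismatched scale changes the attention weights without any error being raised.

```python
import math


def attention_weights(q, keys, scale=None):
    """Softmax attention weights for one query; scale defaults to 1/sqrt(head_dim)."""
    if scale is None:
        scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


q = [1.0, 0.0, 2.0, 1.0]                              # head_dim = 4, so the default scale is 0.5
keys = [[1.0, 1.0, 0.0, 0.0], [0.0, 1.0, 1.0, 1.0]]

default_w = attention_weights(q, keys)                # implicit 1/sqrt(4)
same_w = attention_weights(q, keys, scale=0.5)        # explicit, equal to the default
custom_w = attention_weights(q, keys, scale=0.25)     # hypothetical model-specific scale

print(default_w)
print(same_w)
print(custom_w)
```

If `self.scaling` equals the default, the omission is harmless; passing it explicitly removes the need to rely on that.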
    outputs = []
    for b in range(bs):
        q_packed = query[b]  # (q_len, H, D)
        k_packed = torch.cat([key[b, :ts_kv_len], key[b]], dim=0)
        v_packed = torch.cat([value[b, :ts_kv_len], value[b]], dim=0)

        cu_seqlens_q = torch.tensor([0, 1, q_len], dtype=torch.int32, device=query.device)
        cu_seqlens_k = torch.tensor(
            [0, ts_kv_len, ts_kv_len + kv_len],
            dtype=torch.int32,
            device=query.device,
        )
The ‘best path’ varlen implementation runs a Python loop over bs, repeatedly allocates cu_seqlens_* tensors, and concatenates k_packed/v_packed per batch. This adds avoidable overhead and extra memory traffic, partially offsetting the intended performance/memory gains. Consider packing across the entire batch in a single varlen call (2*bs sub-sequences) and precomputing cu_seqlens_q/cu_seqlens_k once, or at minimum hoist the cu_seqlens_* construction out of the for b in range(bs) loop.
Suggested change:
    -outputs = []
    -for b in range(bs):
    -    q_packed = query[b]  # (q_len, H, D)
    -    k_packed = torch.cat([key[b, :ts_kv_len], key[b]], dim=0)
    -    v_packed = torch.cat([value[b, :ts_kv_len], value[b]], dim=0)
    -    cu_seqlens_q = torch.tensor([0, 1, q_len], dtype=torch.int32, device=query.device)
    -    cu_seqlens_k = torch.tensor(
    -        [0, ts_kv_len, ts_kv_len + kv_len],
    -        dtype=torch.int32,
    -        device=query.device,
    -    )
    +cu_seqlens_q = torch.tensor(
    +    [0, 1, q_len],
    +    dtype=torch.int32,
    +    device=query.device,
    +)
    +cu_seqlens_k = torch.tensor(
    +    [0, ts_kv_len, ts_kv_len + kv_len],
    +    dtype=torch.int32,
    +    device=query.device,
    +)
    +outputs = []
    +for b in range(bs):
    +    q_packed = query[b]  # (q_len, H, D)
    +    k_packed = torch.cat([key[b, :ts_kv_len], key[b]], dim=0)
    +    v_packed = torch.cat([value[b, :ts_kv_len], value[b]], dim=0)
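The other option raised in the comment, a single varlen call over 2*bs sub-sequences, needs batch-wide cumulative offsets. The sketch below computes them with plain Python lists so the bookkeeping is visible. `batched_cu_seqlens` is a hypothetical helper, not part of the PR, and it assumes every sample shares the same `q_len`, `ts_kv_len`, and `kv_len`, as the per-sample code above does.

```python
from itertools import accumulate


def batched_cu_seqlens(bs: int, q_len: int, ts_kv_len: int, kv_len: int):
    """Cumulative offsets for 2*bs packed sub-sequences.

    Each sample contributes a timestamp segment followed by an image segment,
    mirroring the per-sample [0, 1, q_len] and [0, ts_kv_len, ts_kv_len + kv_len].
    """
    q_lens = [1, q_len - 1] * bs        # timestamp query, then image queries
    k_lens = [ts_kv_len, kv_len] * bs   # keys visible to each of those segments
    cu_q = [0, *accumulate(q_lens)]
    cu_k = [0, *accumulate(k_lens)]
    return cu_q, cu_k


# Hypothetical small sizes for illustration.
cu_q, cu_k = batched_cu_seqlens(bs=2, q_len=5, ts_kv_len=3, kv_len=7)
print(cu_q)  # [0, 1, 5, 6, 10]
print(cu_k)  # [0, 3, 10, 13, 20]
```

In the real path these lists would be built once as `torch.int32` tensors on the query's device, and the packed K/V of all samples would be concatenated in the same order before the single kernel call.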
    # Padding tokens (<sp_size) participate without mask; impact is
    # negligible (<0.2% attention mass) and outputs are discarded.
In SP mode you explicitly allow padding tokens to participate in attention without masking. Even if padded token outputs are discarded, their presence still changes the attention distribution for non-padding tokens (i.e., it is not a pure no-op), which can introduce quality regressions or nondeterministic drift depending on how padding embeddings/activations look. If exactness vs the masked baseline is a goal, consider a mask-free equivalent (e.g., varlen attention with per-rank effective lengths, or a packing strategy similar to the non-SP varlen path), or add a clearly documented accuracy/acceptability threshold plus a regression test/metric guard for SP runs.
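The effect described here can be quantified with a small softmax sketch. The logits are hypothetical (all zeros, 1000 real keys, one padding key) and say nothing about the model's actual activations. The point is only that any weight given to an unmasked padding key is taken proportionally from every real key.

```python
import math


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def padding_mass(real_scores, pad_scores):
    """Fraction of one query's attention weight that lands on unmasked padding keys."""
    weights = softmax(list(real_scores) + list(pad_scores))
    return sum(weights[len(real_scores):])


# Hypothetical logits: 1000 real keys and 1 padding key, all with score 0.0.
mass = padding_mass([0.0] * 1000, [0.0])
print(f"padding attention mass: {mass:.4%}")

# Every real key's weight shrinks by the same factor (1 - mass).
masked = softmax([0.0] * 1000)
unmasked = softmax([0.0] * 1001)
print(unmasked[0] / masked[0], 1.0 - mass)
```

A regression guard for SP runs could log this mass per layer and fail when it exceeds the documented threshold.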
    if not first_step and _flash_ok and self.sp_size <= 1 and _has_varlen:
        # Steps 2-50, best path: packed varlen, one kernel, zero mask, exact.
        # Q = [timestamp(1) | image(q_len-1)]
        # timestamp attends to text+ts only; image attends to text+ts+image.
        # KV already excludes eoi (removed from cache).
        #
The attention dispatch has grown into a long multi-branch conditional with subtle per-path invariants (e.g., whether KV includes EOI, whether masks are trimmed, which tokens are present in query). This increases the risk of future changes breaking a single path. A tangible improvement would be to extract each major path into a small helper (e.g., _attn_steps_2_50_nonsp_varlen, _attn_steps_2_50_sp_fix_ts, _attn_step1_nonsp_split, etc.) and centralize shared post-processing (reshape, _unwrap_flash_output) to reduce duplication and make invariants easier to audit.
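One way to read this suggestion is to separate path selection from path execution. The sketch below is a loose illustration, not the PR's logic: it covers only the three helper names mentioned in the comment plus a hypothetical `sdpa_fallback`, and apart from the varlen condition quoted in the diff, the branch conditions are assumptions. The PR's actual six paths are not reproduced.

```python
def select_attention_path(first_step: bool, flash_ok: bool, sp_size: int, has_varlen: bool) -> str:
    """Return the name of the helper that should handle this call.

    Keeping selection in one small pure function makes it unit-testable
    without building any tensors.
    """
    if not flash_ok:
        return "sdpa_fallback"  # hypothetical name for the non-FlashAttention path
    if not first_step and sp_size <= 1 and has_varlen:
        return "_attn_steps_2_50_nonsp_varlen"
    if not first_step and sp_size > 1:
        return "_attn_steps_2_50_sp_fix_ts"
    if first_step and sp_size <= 1:
        return "_attn_step1_nonsp_split"
    return "sdpa_fallback"  # combinations this sketch does not model


for args in [(False, True, 1, True), (False, True, 4, True), (True, True, 1, False)]:
    print(args, "->", select_attention_path(*args))
```

Each returned name would map to a helper that owns its invariants (whether KV includes EOI, which tokens are in the query), with reshape and `_unwrap_flash_output` applied once after the call.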
hsliuustc0106 left a comment
🔍 PR #1975 Review: FlashAttention for HunyuanImage3 Denoising
✅ Passed Items
- Correctness, performance, implementation quality all good
- 6-path dispatch logic is correct
- EOI handling is reasonable
⚠️ Items to Add
1. Functional Correctness Tests (Required)
- Verify numerical consistency between FlashAttention and SDPA paths
- Cover all 6 dispatch paths
2. E2E Tests (Strongly Recommended)
- Verify actual image generation quality
- Compare outputs with fixed seed
📝 Recommendations
Before merging, please add:
# 1. Functional correctness tests
tests/diffusion/models/hunyuan_image_3/test_flash_attention.py
- Verify numerical consistency across all 6 dispatch paths with SDPA baseline
# 2. E2E tests
tests/e2e/test_hunyuan_image3_flash_attn.py
- Verify actual image generation quality
- Compare Flash and SDPA outputs with fixed seed
Code quality is excellent, implementation logic is correct. Ready to merge after adding tests. 👍
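The shape of such a numerical-consistency check can be sketched without a GPU. The pure-Python reference below compares grouped-query attention computed with an explicit `repeat_kv`-style expansion against the same computation that indexes the shared KV head directly, which is the equivalence a FlashAttention-versus-SDPA test relies on. Head counts, sizes, and function names are hypothetical, and the head mapping assumes the usual convention that each KV head serves `n_rep` consecutive query heads. A real test would run the model's dispatch paths on tensors and use a dtype-appropriate tolerance.

```python
import math
import random


def attend(q, keys, values, scale):
    """Single-query softmax attention over lists of key and value vectors."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]


def gqa_expanded(q_heads, k_heads, v_heads, scale):
    """Baseline: repeat each KV head n_rep times, then attend head by head."""
    n_rep = len(q_heads) // len(k_heads)
    k_rep = [k for k in k_heads for _ in range(n_rep)]
    v_rep = [v for v in v_heads for _ in range(n_rep)]
    return [attend(q, k, v, scale) for q, k, v in zip(q_heads, k_rep, v_rep)]


def gqa_grouped(q_heads, k_heads, v_heads, scale):
    """No expansion: query head h reads KV head h // n_rep directly."""
    n_rep = len(q_heads) // len(k_heads)
    return [attend(q, k_heads[h // n_rep], v_heads[h // n_rep], scale)
            for h, q in enumerate(q_heads)]


rng = random.Random(0)  # fixed seed, as the review asks for reproducible comparisons
n_q_heads, n_kv_heads, seq, dim = 4, 2, 5, 3


def rand_vec():
    return [rng.uniform(-1.0, 1.0) for _ in range(dim)]


q_heads = [rand_vec() for _ in range(n_q_heads)]
k_heads = [[rand_vec() for _ in range(seq)] for _ in range(n_kv_heads)]
v_heads = [[rand_vec() for _ in range(seq)] for _ in range(n_kv_heads)]

baseline = gqa_expanded(q_heads, k_heads, v_heads, dim ** -0.5)
candidate = gqa_grouped(q_heads, k_heads, v_heads, dim ** -0.5)
max_diff = max(abs(x - y) for hb, hc in zip(baseline, candidate) for x, y in zip(hb, hc))
print(f"max abs diff: {max_diff:.3e}")
```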
Force-pushed: 61a061f to 84db394
Replace SDPA with a packed varlen FlashAttention path for HunyuanImage3 image denoising iterations 2-50, with SP-aware dispatch and formatting cleanup. Signed-off-by: Ding Zuhao <e1583181@u.nus.edu>
…ntion Signed-off-by: Ding Zuhao <e1583181@u.nus.edu>
The eoi token is never attended to in denoising steps 2-50, so drop it from the KV cache to save memory and simplify the attention path. Signed-off-by: Ding Zuhao <e1583181@u.nus.edu>
…tion dispatch paths Signed-off-by: Ding Zuhao <e1583181@u.nus.edu>
Force-pushed: 84db394 to 0564715

Summary
- Replace `repeat_kv` attention with split FlashAttention for HunyuanImage3 denoising, eliminating GQA KV expansion (4x memory saving) and mask overhead.
- … `ragged_final_layer` via `masked_select(image_mask)`. Saves cache memory and eliminates dead computation in step 1.
Key Design Decisions
- `image_mask` marks only image token positions as True; the eoi position is False. `ragged_final_layer` uses `masked_select(image_mask)` (step 1) or `x[:, 1:]` (steps 2+) to extract only image tokens. The eoi computation in step 1 Path ④ is replaced with `torch.zeros_like`.
Test Plan
Test Result
Theoretical test value
Actual test value
Parameter Explanation