Skip to content

[Feature] [HunyuanImage3] Use FlashAttention for image denoising#1975

Open
nussejzz wants to merge 4 commits into
vllm-project:mainfrom
nussejzz:feat/hunyuan-image3-flash-attn
Open

[Feature] [HunyuanImage3] Use FlashAttention for image denoising#1975
nussejzz wants to merge 4 commits into
vllm-project:mainfrom
nussejzz:feat/hunyuan-image3-flash-attn

Conversation

@nussejzz
Copy link
Copy Markdown
Contributor

@nussejzz nussejzz commented Mar 18, 2026

Summary

  • Replace 4D mask + repeat_kv attention with split FlashAttention for HunyuanImage3 denoising, eliminating GQA KV expansion (4x memory saving) and mask overhead
  • Six-path dispatch covering step 1 / steps 2-50 × non-SP / SP / fallback:
    • Steps 2-50, non-SP: packed varlen FA (best) or two FA calls (fallback) — timestamp→text+ts, image→all
    • Steps 2-50, SP: maskless Ulysses SP Attention + local timestamp fix on rank 0
    • Step 1, non-SP: causal(text+ts) + full(image→text+ts+image), eoi zeroed out
    • Step 1, SP: text causal (local) + image SP (text as joint KV) + timestamp fix on rank 0
    • Fallback: original 4D mask path (fp32 or no flash_attn)
  • Remove eoi token from KV cache — no token ever attends to eoi's KV (mask column all-False), and eoi's hidden state is discarded by ragged_final_layer via masked_select(image_mask). Saves cache memory and eliminates dead computation in step 1
  • Path selection is fully automatic based on dtype (fp16/bf16), flash_attn availability, and sp_size. Compatible with TP.

Key Design Decisions

  • Maskless SP + local fix: instead of passing 4D mask through Ulysses AllToAll (impossible without repeat_kv), run full attention in SP then fix timestamp output on rank 0 with a single Q_len=1 FA call. Zero extra memory, negligible overhead.
  • eoi is a ghost token: verified in both original HunyuanImage-3.0 code and vllm-omni — image_mask marks only image token positions as True; eoi position is False. ragged_final_layer uses masked_select(image_mask) (step 1) or x[:, 1:] (steps 2+) to extract only image tokens. eoi computation in step 1 Path ④ replaced with torch.zeros_like.

Test Plan

python3 benchmarks/bench_hunyuan_flash_attn.py \
  --model /data/HunyuanImage-3.0 \
  --stage-configs-path vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit.yaml \
  --prompt "A brown and white dog is running on the grass" \
  --output-prefix output_bench \
  2>&1 | tee bench_3way.log

Test Result

image

Theoretical test value

image

Actual test value

image

Parameter Explanation

Stage varlen flash sdpa
prep/reshape — input tensor transform reshape: 0.085ms prep: 0.050ms transpose×3 + .contiguous(): 0.053ms
repeat_kv — KV heads expansion (8→32) None None 0.033ms
kernel — attention computation 0.535ms 0.555ms 1.152ms
post — output transform None reshape: 0.014ms transpose + reshape: 0.015ms
total 0.622ms 0.619ms 1.253ms

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan. Please provide the test scripts & test commands. Please state the reasons if your codes don't require additional test scripts. For test file guidelines, please check the test style doc
  • The test results. Please paste the results comparison before and after, or the e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user-facing, please update the release notes draft.

BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)

@nussejzz nussejzz requested a review from hsliuustc0106 as a code owner March 18, 2026 06:03
Comment thread vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py Outdated
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 5353b9e686

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py Outdated
Comment thread vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py Outdated
Copy link
Copy Markdown
Collaborator

@Gaohan123 Gaohan123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please supplement performance test results

@nussejzz
Copy link
Copy Markdown
Contributor Author

Please supplement performance test results

OK~

@nussejzz nussejzz force-pushed the feat/hunyuan-image3-flash-attn branch 4 times, most recently from 1bd0d0f to d254f88 Compare March 21, 2026 12:46
@nussejzz nussejzz requested review from Gaohan123 and gcanlin March 21, 2026 12:50
@hsliuustc0106
Copy link
Copy Markdown
Collaborator

why it only helps improve a little?

@nussejzz
Copy link
Copy Markdown
Contributor Author

why it only helps improve a little?

From my point of view, it is because the 1.06x is end-to-end (including MoE FFN with 64 experts, VAE decode, etc.), not attention-only. Attention is a small fraction of per-step compute

@hsliuustc0106
Copy link
Copy Markdown
Collaborator

why it only helps improve a little?

From my point of view, it is because the 1.06x is end-to-end (including MoE FFN with 64 experts, VAE decode, etc.), not attention-only. Attention is a small fraction of per-step compute

so, can you profile it?

@nussejzz
Copy link
Copy Markdown
Contributor Author

why it only helps improve a little?

From my point of view, it is because the 1.06x is end-to-end (including MoE FFN with 64 experts, VAE decode, etc.), not attention-only. Attention is a small fraction of per-step compute

so, can you profile it?

Sure, I'll do further analysis and testing once the foreign exchange arrives tomorrow.😊❤️

Copy link
Copy Markdown
Collaborator

@lishunyang12 lishunyang12 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implementation looks correct. Two things:

  1. Use _unwrap_flash_output helper instead of the inline isinstance(attn_output, tuple) check — the codebase already has this utility.

  2. Test results (latency + numerical accuracy vs SDPA path) should be provided before merge.

@nussejzz nussejzz force-pushed the feat/hunyuan-image3-flash-attn branch 4 times, most recently from 9f2caf8 to 01d4b18 Compare March 22, 2026 06:29
@nussejzz nussejzz requested a review from lishunyang12 March 22, 2026 06:32
@nussejzz
Copy link
Copy Markdown
Contributor Author

nussejzz commented Mar 22, 2026

@nussejzz nussejzz force-pushed the feat/hunyuan-image3-flash-attn branch 2 times, most recently from e607efe to dcebc95 Compare March 24, 2026 13:37
Copy link
Copy Markdown
Collaborator

@princepride princepride left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but I want know why you changed the yaml file?

Comment thread vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml
Comment thread vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit.yaml Outdated
Comment thread vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml Outdated
@nussejzz nussejzz requested a review from princepride March 24, 2026 14:45
@princepride princepride removed the ready label to trigger buildkite CI label Mar 25, 2026
Copy link
Copy Markdown
Collaborator

@princepride princepride left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need consider timestamp token's attention😂

@nussejzz
Copy link
Copy Markdown
Contributor Author

We need consider timestamp token's attention😂

No problem

@nussejzz nussejzz force-pushed the feat/hunyuan-image3-flash-attn branch 4 times, most recently from 8f767bb to ba1de13 Compare March 30, 2026 02:47
@nussejzz nussejzz changed the title [Feature] [HunyuanImage3] Use FlashAttention for image denoising steps 2-50 [Feature] [HunyuanImage3] Use FlashAttention for image denoising Mar 30, 2026
@lishunyang12
Copy link
Copy Markdown
Collaborator

Can you profile it using torch profiler ? Userful links: https://docs.vllm.ai/projects/vllm-omni/en/latest/contributing/profiling/

@hsliuustc0106 hsliuustc0106 requested review from Copilot and removed request for Bounty-hunter April 2, 2026 06:39
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR updates HunyuanImage-3 denoising attention to use FlashAttention where possible, reducing memory usage by avoiding GQA KV expansion (repeat_kv) and removing the unused EOI token from KV caches.

Changes:

  • Add FlashAttention-based multi-path dispatch for step 1 vs steps 2–50, and non-SP vs SP execution.
  • Remove EOI from image KV caching and adjust masks/kv lengths accordingly.
  • Stop forcing DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA in the pipeline constructor.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

File Description
vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py Removes global env forcing of SDPA backend so the model can select FA paths automatically.
vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py Implements FlashAttention dispatch paths, packed varlen FA, SP timestamp fix, and removes EOI KV caching.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len - 1,
max_seqlen_k=kv_len,
causal=False,
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The varlen FlashAttention call doesn’t pass softmax_scale=self.scaling, while the other FlashAttention branches in this function do. If self.scaling differs from FlashAttention’s default scaling, this will silently change model numerics vs the other FA paths and vs the original masked implementation. Pass softmax_scale=self.scaling (if supported by your _flash_attn_varlen_func wrapper), or otherwise ensure/document that self.scaling is exactly the FA default for this model.

Suggested change
causal=False,
causal=False,
softmax_scale=self.scaling,

Copilot uses AI. Check for mistakes.
Comment on lines +1078 to 1090
outputs = []
for b in range(bs):
q_packed = query[b] # (q_len, H, D)
k_packed = torch.cat([key[b, :ts_kv_len], key[b]], dim=0)
v_packed = torch.cat([value[b, :ts_kv_len], value[b]], dim=0)

cu_seqlens_q = torch.tensor([0, 1, q_len], dtype=torch.int32, device=query.device)
cu_seqlens_k = torch.tensor(
[0, ts_kv_len, ts_kv_len + kv_len],
dtype=torch.int32,
device=query.device,
)

Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ‘best path’ varlen implementation runs a Python loop over bs, repeatedly allocates cu_seqlens_* tensors, and concatenates k_packed/v_packed per batch. This adds avoidable overhead and extra memory traffic, partially offsetting the intended performance/memory gains. Consider packing across the entire batch in a single varlen call (2*bs sub-sequences) and precomputing cu_seqlens_q/cu_seqlens_k once, or at minimum hoist the cu_seqlens_* construction out of the for b in range(bs) loop.

Suggested change
outputs = []
for b in range(bs):
q_packed = query[b] # (q_len, H, D)
k_packed = torch.cat([key[b, :ts_kv_len], key[b]], dim=0)
v_packed = torch.cat([value[b, :ts_kv_len], value[b]], dim=0)
cu_seqlens_q = torch.tensor([0, 1, q_len], dtype=torch.int32, device=query.device)
cu_seqlens_k = torch.tensor(
[0, ts_kv_len, ts_kv_len + kv_len],
dtype=torch.int32,
device=query.device,
)
cu_seqlens_q = torch.tensor(
[0, 1, q_len],
dtype=torch.int32,
device=query.device,
)
cu_seqlens_k = torch.tensor(
[0, ts_kv_len, ts_kv_len + kv_len],
dtype=torch.int32,
device=query.device,
)
outputs = []
for b in range(bs):
q_packed = query[b] # (q_len, H, D)
k_packed = torch.cat([key[b, :ts_kv_len], key[b]], dim=0)
v_packed = torch.cat([value[b, :ts_kv_len], value[b]], dim=0)

Copilot uses AI. Check for mistakes.
Comment on lines +1127 to +1128
# Padding tokens (<sp_size) participate without mask; impact is
# negligible (<0.2% attention mass) and outputs are discarded.
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In SP mode you explicitly allow padding tokens to participate in attention without masking. Even if padded token outputs are discarded, their presence still changes the attention distribution for non-padding tokens (i.e., it is not a pure no-op), which can introduce quality regressions or nondeterministic drift depending on how padding embeddings/activations look. If exactness vs the masked baseline is a goal, consider a mask-free equivalent (e.g., varlen attention with per-rank effective lengths, or a packing strategy similar to the non-SP varlen path), or add a clearly documented accuracy/acceptability threshold plus a regression test/metric guard for SP runs.

Copilot uses AI. Check for mistakes.
Comment on lines +1065 to +1070
if not first_step and _flash_ok and self.sp_size <= 1 and _has_varlen:
# Steps 2-50, best path: packed varlen, one kernel, zero mask, exact.
# Q = [timestamp(1) | image(q_len-1)]
# timestamp attends to text+ts only; image attends to text+ts+image.
# KV already excludes eoi (removed from cache).
#
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The attention dispatch has grown into a long multi-branch conditional with subtle per-path invariants (e.g., whether KV includes EOI, whether masks are trimmed, which tokens are present in query). This increases the risk of future changes breaking a single path. A tangible improvement would be to extract each major path into a small helper (e.g., _attn_steps_2_50_nonsp_varlen, _attn_steps_2_50_sp_fix_ts, _attn_step1_nonsp_split, etc.) and centralize shared post-processing (reshape, _unwrap_flash_output) to reduce duplication and make invariants easier to audit.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔍 PR #1975 Review: FlashAttention for HunyuanImage3 Denoising

Passed Items

  • Correctness, performance, implementation quality all good
  • 6-path dispatch logic is correct
  • EOI handling is reasonable

⚠️ Items to Add

1. Functional Correctness Tests (Required)

  • Verify numerical consistency between FlashAttention and SDPA paths
  • Cover all 6 dispatch paths

2. E2E Tests (Strongly Recommended)

  • Verify actual image generation quality
  • Compare outputs with fixed seed

📝 Recommendations

Before merging, please add:

# 1. Functional correctness tests
tests/diffusion/models/hunyuan_image_3/test_flash_attention.py
  • Verify numerical consistency across all 6 dispatch paths with SDPA baseline
# 2. E2E tests
tests/e2e/test_hunyuan_image3_flash_attn.py
  • Verify actual image generation quality
  • Compare Flash and SDPA outputs with fixed seed

Code quality is excellent, implementation logic is correct. Ready to merge after adding tests. 👍

@nussejzz
Copy link
Copy Markdown
Contributor Author

Functional Correctness Test

Clipboard_Screenshot_1776085450

@nussejzz nussejzz force-pushed the feat/hunyuan-image3-flash-attn branch 2 times, most recently from 61a061f to 84db394 Compare April 13, 2026 14:08
Replace SDPA with a packed varlen FlashAttention path for HunyuanImage3
image denoising iterations 2-50, with SP-aware dispatch and formatting
cleanup.

Signed-off-by: Ding Zuhao <e1583181@u.nus.edu>
…ntion

Signed-off-by: Ding Zuhao <e1583181@u.nus.edu>
The eoi token is never attended to in denoising steps 2-50, so drop it
from the KV cache to save memory and simplify the attention path.

Signed-off-by: Ding Zuhao <e1583181@u.nus.edu>
…tion dispatch paths

Signed-off-by: Ding Zuhao <e1583181@u.nus.edu>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants