[Bugfix/Feature] Remove Hardcoded Flash Attention in Bagel & Support GQA in SDPA Backend#3728
Conversation
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
| dropout_p=0.0, | ||
| is_causal=self.causal, | ||
| scale=self.softmax_scale, | ||
| enable_gqa=True, |
There was a problem hiding this comment.
Also @gcanlin, can you please take a look at this to make sure you are ok with this change? GQA support is needed for Bagel, but this is the only change in the attn backend 🙂 otherwise it'll crash with a dims mismatch
hsliuustc0106
left a comment
There was a problem hiding this comment.
Review cannot proceed until docs build passes. Please fix the docs build failure.
|
Thanks @hsliuustc0106, resolved the docs issue! |
|
@alex-jw-brooks Thanks! That's what I expect. I will check the details later. |
hsliuustc0106
left a comment
There was a problem hiding this comment.
Review Summary
Reviewed PR #3728: Remove Hardcoded Flash Attention in Bagel & Support GQA in SDPA Backend.
What I validated
- Review gates: DCO, pre-commit, build (3.11/3.12) all passing. Docs build still pending but author confirmed docs issue resolved.
- Correctness of attention refactor: The refactored
_forward_genand_forward_undcorrectly replaceflash_attn_varlen_funcwith the abstractedDiffusionAttentionlayer. The SP path preserves the joint mechanism. The non-SP gen path correctly concatenates text+vae tokens for bidirectional attention. The und path correctly handles KV cache appending and causal masking. - Multi-request batching concern investigated: The Bagel pipeline (
BagelPipeline.forward()) enforces single-request — it warns and truncates to 1 prompt if multiple are provided. All downstream calls (prepare_prompts,generate_text,generate_image) receive single-element inputs. The varlen→batched attention change is therefore safe in practice. The CFG batching ingenerate_imageis multi-branch (gen/cfg_text/cfg_img), not multi-request. - SDPA
enable_gqa=True: Safe for all models — no-op whennum_heads == num_kv_heads, correct expansion whennum_kv_heads < num_heads. PyTorch >= 2.5.0 requirement is met. Other models that pre-expand K/V (e.g., HunyuanImage3) will see it as a no-op. - Output difference: Author provides thorough evidence (visual comparison + minimal patch) attributing the change to
flash_attn_varlen_funcvsflash_attn_funcnumerical differences. This is a known artifact of switching between varlen and non-varlen Flash Attention APIs. - Test update:
test_trajectory_recording.pycorrectly updated to match new API (removed dead params). - Dead arg cleanup: The removal of
packed_query_indexes,key_values_lens,packed_key_value_indexesacross the entire call chain (model, pipeline, teacache extractor, tests) is consistent and thorough.
Non-blocking observations
-
_forward_unddocstring: Step numbering goes 1-6 then skips to 8 (line ~530 in new code). Trivial fix. -
Internal API vs actual usage mismatch: The model's internal API still maintains packed-format infrastructure (
packed_seqlens,packed_text_indexes,packed_vae_token_indexes) that suggests multi-request batching capability, but the pipeline enforces single-request. The author's TODO about removing packing is a good forward step — this would simplify the model significantly. -
enable_gqa=Truescope: This change applies to ALL diffusion models using the SDPA backend, not just Bagel. While it's safe (no-op for non-GQA models), it's worth being aware of the cross-model impact.
Verdict
No blocking issues found. The refactoring is a clear improvement in readability and backend flexibility. The evidence is solid. Approving from a correctness standpoint.
Reviewed by Claude Code with glm-5.1
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
|
Thanks @hsliuustc0106! Made the changes, should be ready for another look when you have a moment |
|
@codex review |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: b7e654ac47
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| attn_out = attn( | ||
| q.unsqueeze(0), | ||
| full_k.unsqueeze(0), | ||
| full_v.unsqueeze(0), |
There was a problem hiding this comment.
Preserve the cache offset when using causal SDPA
When past_key_values is populated and the selected backend is SDPA, this calls causal attention with query containing only the new tokens but key containing cache + new. PyTorch SDPA's is_causal=True uses an upper-left non-square mask, so a decode step with q_len=1 can only attend to the first cached key instead of all previous keys plus the current token. This breaks cached autoregressive text generation even for a single request; use an offset/lower-right causal bias or avoid SDPA's bare is_causal for cached decode.
Useful? React with 👍 / 👎.
| q = torch.cat([text_q, vae_q], dim=0).unsqueeze(0) | ||
| k = torch.cat([ctx_k, vae_k], dim=0).unsqueeze(0) | ||
| v = torch.cat([ctx_v, vae_v], dim=0).unsqueeze(0) | ||
| attn_out = self.attn_noncausal(q, k, v) |
There was a problem hiding this comment.
Keep batched CFG branches isolated in attention
In non-SP image generation with CFG enabled, _generate_image_single builds cfg_batched by concatenating the main/text/img branches and passes batched_query_lens, but this path flattens all text and VAE tokens into a single SDPA sequence and never uses those lengths. As a result, branch 0/1/2 tokens attend to each other instead of remaining independent CFG branches, corrupting the guidance prediction whenever batched CFG is used.
Useful? React with 👍 / 👎.
…GQA in SDPA Backend (#3728) Signed-off-by: Alex Brooks <albrooks@redhat.com>
Two divergences vs upstream Lance: 1. Initial noise mismatch — Bagel.prepare_input / Lance's prepare_video_latent call torch.randn(...) with no device or generator, falling back to CPU+fp32 + the global torch.manual_seed stream. Upstream samples directly on CUDA+bf16 via torch.Generator(device=cuda).manual_seed(seed) (lance.py:1536). That gives a totally different noise tensor for the same seed. Fix: regen packed_init_noises on-device with a fresh generator after prepare_vae_latent / prepare_video_latent in t2i, t2v, image_edit (i2v already had this). Result: CK1 byte-perfect (max_diff=0, cos=1.0). 2. Batched CFG attention leakage — Bagel.forward packs cond + cfg_text branches into a single LLM forward with concatenated KV caches. PR vllm-project#3728 removed the hardcoded flash_attn_varlen call and replaced it with DiffusionAttention, but the cu_seqlens-based block-diagonal mask was lost in translation. Without it, cond tokens attend to cfg_text tokens (and vice versa) on every layer, producing a model that no longer matches upstream's sequential per-branch forwards. Until the mask plumbing is restored, run each CFG branch sequentially through the single-forward path — matches upstream exactly. Result: layer00 input identical, CK7 cosine 0.186 -> 0.965. Also adds: - vllm_omni/diffusion/models/bagel/_dump_hooks.py — env-gated (LANCE_DUMP_DIR) layer-by-layer checkpoint dumper so future alignment work can run lance-upstream/lance_compare/compare.py side-by-side. Zero cost when the env is unset. - DiffusionEngine._dummy_run skips when LANCE_DUMP_DIR is set (the 512x512 warmup contaminates the dump harness's per-call state before the real request arrives). Verified all 7 Lance tasks visually: - t2i corgi astronaut now nearly identical to upstream - t2v red panda surfing - i2v snow leopard glacier - image_edit hat removed cleanly from portrait - video_edit red car on snowy road - x2t_image / x2t_video already aligned via the cache-offset SDPA fix. Signed-off-by: lishunyang12 <lishunyang12@163.com>
…GQA in SDPA Backend (vllm-project#3728) Signed-off-by: Alex Brooks <albrooks@redhat.com>
The refactoring in vllm-project#3728 replaced direct flash_attn_varlen_func calls with DiffusionAttention routing but used unsqueeze(0) for batch=1, causing all CFG branches to cross-attend as one sequence. Fix: build proper 4D (num_branches, seq, heads, dim) tensors with left-padded K and attention mask. The existing _forward_varlen_masked backend handles variable-length sequences correctly. Also re-enables the text2img shared memory test that was disabled by issue vllm-project#3977 due to this bug. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Alex Brooks <albrooks@redhat.com>
…GQA in SDPA Backend (vllm-project#3728) Signed-off-by: Alex Brooks <albrooks@redhat.com>
Purpose
Fix: #3301
Follow-up on Bagel fixes to support Bagel on other attention backends.
Outputs do change a bit, but I think this is due to the hardcoded
flash_attn_varlen_funcvsflash_attn_func, see details below for a minimal patch that'll flip the input to what this branch produces.Minimal test script:
On

mainwith no SP & with SP you should get:On this branch with no SP & with SP you should get:

You can also run it with
DIFFUSION_ATTENTION_BACKEND=TORCH_SDPAand make sure you get the same result.To confirm the difference is due to
flash_attn_varlen_funcvsflash_attn_funcand not something else, you can use this as a minimal patch onmain:You should see the output flip to what we get on this branch.
@princepride @Gaohan123 can you please take a look? Thanks!