
[Bugfix/Feature] Remove Hardcoded Flash Attention in Bagel & Support GQA in SDPA Backend#3728

Merged
hsliuustc0106 merged 9 commits into
vllm-project:mainfrom
alex-jw-brooks:bagel_fa
May 29, 2026


@alex-jw-brooks
Collaborator

@alex-jw-brooks alex-jw-brooks commented May 19, 2026

Purpose

Fix: #3301
Follow-up on Bagel fixes to support Bagel on other attention backends.

  • Removes the hard-coded Flash Attention code from Bagel so that other backends can be used
  • Removes dead args & refactors a bit to try to make the gen / und paths more readable
  • Allows GQA for the SDPA backend (needed for Bagel)

Outputs do change a bit, but I think this is due to the hardcoded flash_attn_varlen_func vs flash_attn_func; see the details below for a minimal patch that flips the output on main to what this branch produces.

Minimal test script:

import argparse
import os

OUTPUT_DIR = os.path.dirname(os.path.abspath(__file__))


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("prompt", nargs="?", default="A cute cat sitting on a windowsill")
    parser.add_argument("--ulysses-degree", type=int, default=1)
    parser.add_argument("--ring-degree", type=int, default=1)
    return parser.parse_args()


def main():
    args = parse_args()

    from vllm_omni.entrypoints.omni import Omni
    has_parallel = args.ulysses_degree > 1 or args.ring_degree > 1
    parallel_kwargs = {} if not has_parallel else {
        "parallel_config": {
            "ulysses_degree": args.ulysses_degree,
            "ring_degree": args.ring_degree,
        }
    }

    omni = Omni(
        model="ByteDance-Seed/BAGEL-7B-MoT",
        **parallel_kwargs,
    )

    formatted_prompt = {
        "prompt": f"<|im_start|>{args.prompt}<|im_end|>",
        "modalities": ["image"],
    }

    params_list = omni.default_sampling_params_list
    omni_outputs = list(omni.generate(prompts=[formatted_prompt], sampling_params_list=params_list))

    ar_output = omni_outputs[0]
    ro = getattr(ar_output, 'request_output', None)
    if ro and getattr(ro, 'outputs', None):
        token_ids = [tid for o in ro.outputs for tid in (o.token_ids or [])]
        print(f"AR stage tokens: {token_ids}")

    save_path = os.path.join(OUTPUT_DIR, "output_ulysses_{}_ring_{}.png".format(args.ulysses_degree, args.ring_degree))
    omni_outputs[1].images[0].save(save_path)
    print(f"Saved {save_path}")


if __name__ == "__main__":
    main()

On main, both with and without SP, you should get:
[image: output_ulysses_1_ring_1]

On this branch, both with and without SP, you should get:
[image: output_ulysses_1_ring_1_new]
You can also run it with DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA and confirm that you get the same result.

To confirm that the difference is due to flash_attn_varlen_func vs flash_attn_func and not something else, you can apply this minimal patch on main:

--- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py
+++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py
@@ -616,14 +616,11 @@ class PackedAttentionMoT(nn.Module):
         cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
         cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
 
-        packed_attn_output = flash_attn_varlen_func(
-            q=packed_query_states,
-            k=merged_key_states,
-            v=merged_value_states,
-            cu_seqlens_q=cu_seqlens_q.to(torch.int32),
-            cu_seqlens_k=cu_seqlens_k.to(torch.int32),
-            max_seqlen_q=max(query_lens).item(),
-            max_seqlen_k=max(key_values_lens).item(),
+        from vllm_omni.diffusion.attention.backends.utils.fa import flash_attn_func
+        packed_attn_output = flash_attn_func(
+            packed_query_states.unsqueeze(0),
+            merged_key_states.unsqueeze(0),
+            merged_value_states.unsqueeze(0),
             causal=is_causal,
         )
         packed_attn_output = packed_attn_output.reshape(-1, self.q_size)

You should see the output flip to what we get on this branch.
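
For readers less familiar with the varlen interface, here is a minimal, framework-free sketch of what the cu_seqlens arguments in the patch above encode. Plain Python lists stand in for tensors, and the helper names cu_seqlens and segments are hypothetical (they are not part of vllm_omni or flash-attn). With a single packed sequence both calls cover the same token range, which is consistent with the author's hypothesis that the output change comes from the kernel rather than from masking.

```python
from itertools import accumulate


def cu_seqlens(lens):
    """Cumulative sequence boundaries with a leading zero.

    Mirrors pad(cumsum(lens), (1, 0)) from the patch, using plain lists.
    """
    return [0, *accumulate(lens)]


def segments(lens):
    """(start, end) token ranges that a varlen kernel treats as separate sequences."""
    bounds = cu_seqlens(lens)
    return list(zip(bounds[:-1], bounds[1:]))


# One packed sequence: a single segment spanning every token.
print(cu_seqlens([7]))      # [0, 7]

# Three packed sequences: a varlen kernel keeps them apart, whereas
# unsqueeze(0) would present all 12 tokens as one sequence.
print(segments([5, 3, 4]))  # [(0, 5), (5, 8), (8, 12)]
```

The second case only matters if more than one sequence is ever packed into the same call.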

@princepride @Gaohan123 can you please take a look? Thanks!

@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository wide code reviews.

dropout_p=0.0,
is_causal=self.causal,
scale=self.softmax_scale,
enable_gqa=True,
Collaborator Author

Also @gcanlin, can you please take a look at this to make sure you are ok with this change? GQA support is needed for Bagel, but this is the only change in the attn backend 🙂 otherwise it'll crash with a dims mismatch
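
As a rough illustration of what grouped-query attention asks of the backend, the sketch below shows how query heads map onto a smaller set of key/value heads. It is plain Python, the helper name kv_head_for is hypothetical, and the head counts are made up for illustration rather than taken from Bagel. The mapping matches the repeat-interleave expansion that PyTorch documents for enable_gqa=True.

```python
def kv_head_for(query_head, num_heads, num_kv_heads):
    """Index of the K/V head that a given query head reads under GQA.

    Each group of num_heads // num_kv_heads consecutive query heads
    shares one K/V head.
    """
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group_size = num_heads // num_kv_heads
    return query_head // group_size


# Hypothetical GQA layout: 8 query heads sharing 2 K/V heads.
print([kv_head_for(h, 8, 2) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]

# Equal head counts: the mapping is the identity, so the flag changes nothing.
print([kv_head_for(h, 4, 4) for h in range(4)])  # [0, 1, 2, 3]
```

Without this expansion, Q and K/V disagree on the head dimension, which is the dims mismatch mentioned above.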

Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

Review cannot proceed until docs build passes. Please fix the docs build failure.

@alex-jw-brooks
Collaborator Author

Thanks @hsliuustc0106, resolved the docs issue!

@gcanlin gcanlin self-assigned this May 21, 2026
@gcanlin
Collaborator

gcanlin commented May 21, 2026

@alex-jw-brooks Thanks! That's what I expect. I will check the details later.

Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

Review Summary

Reviewed PR #3728: Remove Hardcoded Flash Attention in Bagel & Support GQA in SDPA Backend.

What I validated

  • Review gates: DCO, pre-commit, build (3.11/3.12) all passing. Docs build still pending but author confirmed docs issue resolved.
  • Correctness of attention refactor: The refactored _forward_gen and _forward_und correctly replace flash_attn_varlen_func with the abstracted DiffusionAttention layer. The SP path preserves the joint mechanism. The non-SP gen path correctly concatenates text+vae tokens for bidirectional attention. The und path correctly handles KV cache appending and causal masking.
  • Multi-request batching concern investigated: The Bagel pipeline (BagelPipeline.forward()) enforces single-request — it warns and truncates to 1 prompt if multiple are provided. All downstream calls (prepare_prompts, generate_text, generate_image) receive single-element inputs. The varlen→batched attention change is therefore safe in practice. The CFG batching in generate_image is multi-branch (gen/cfg_text/cfg_img), not multi-request.
  • SDPA enable_gqa=True: Safe for all models — no-op when num_heads == num_kv_heads, correct expansion when num_kv_heads < num_heads. PyTorch >= 2.5.0 requirement is met. Other models that pre-expand K/V (e.g., HunyuanImage3) will see it as a no-op.
  • Output difference: Author provides thorough evidence (visual comparison + minimal patch) attributing the change to flash_attn_varlen_func vs flash_attn_func numerical differences. This is a known artifact of switching between varlen and non-varlen Flash Attention APIs.
  • Test update: test_trajectory_recording.py correctly updated to match new API (removed dead params).
  • Dead arg cleanup: The removal of packed_query_indexes, key_values_lens, packed_key_value_indexes across the entire call chain (model, pipeline, teacache extractor, tests) is consistent and thorough.

Non-blocking observations

  1. _forward_und docstring: Step numbering goes 1-6 then skips to 8 (line ~530 in new code). Trivial fix.

  2. Internal API vs actual usage mismatch: The model's internal API still maintains packed-format infrastructure (packed_seqlens, packed_text_indexes, packed_vae_token_indexes) that suggests multi-request batching capability, but the pipeline enforces single-request. The author's TODO about removing packing is a good forward step — this would simplify the model significantly.

  3. enable_gqa=True scope: This change applies to ALL diffusion models using the SDPA backend, not just Bagel. While it's safe (no-op for non-GQA models), it's worth being aware of the cross-model impact.

Verdict

No blocking issues found. The refactoring is a clear improvement in readability and backend flexibility. The evidence is solid. Approving from a correctness standpoint.


Reviewed by Claude Code with glm-5.1

Comment thread vllm_omni/diffusion/attention/backends/sdpa.py Outdated
Comment thread vllm_omni/diffusion/models/bagel/bagel_transformer.py
Signed-off-by: Alex Brooks <albrooks@redhat.com>
@alex-jw-brooks
Collaborator Author

Thanks @hsliuustc0106! Made the changes, should be ready for another look when you have a moment

@ZJY0516
Member

ZJY0516 commented May 28, 2026

@codex review

Member

@ZJY0516 ZJY0516 left a comment

overall LGTM


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: b7e654ac47

Comment on lines +565 to +568
attn_out = attn(
q.unsqueeze(0),
full_k.unsqueeze(0),
full_v.unsqueeze(0),

P1: Preserve the cache offset when using causal SDPA

When past_key_values is populated and the selected backend is SDPA, this calls causal attention with query containing only the new tokens but key containing cache + new. PyTorch SDPA's is_causal=True uses an upper-left non-square mask, so a decode step with q_len=1 can only attend to the first cached key instead of all previous keys plus the current token. This breaks cached autoregressive text generation even for a single request; use an offset/lower-right causal bias or avoid SDPA's bare is_causal for cached decode.
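
The alignment issue described above can be seen in a small framework-free sketch. Boolean lists stand in for attention masks, and the helper names causal_mask and visible_keys are hypothetical. The "upper_left" mode follows the behaviour PyTorch documents for is_causal=True with a non-square mask; the "lower_right" mode is the offset variant a cached decode step would need.

```python
def causal_mask(q_len, k_len, align="upper_left"):
    """Boolean mask where True means query row i may attend to key column j.

    "upper_left" anchors the diagonal at (0, 0).
    "lower_right" shifts it by k_len - q_len so the last query sees every key.
    """
    if align not in ("upper_left", "lower_right"):
        raise ValueError(f"unknown alignment: {align}")
    offset = 0 if align == "upper_left" else k_len - q_len
    return [[j <= i + offset for j in range(k_len)] for i in range(q_len)]


def visible_keys(mask_row):
    """Key positions a single query row is allowed to attend to."""
    return [j for j, allowed in enumerate(mask_row) if allowed]


# Decode step: 1 new query token, 4 cached keys plus 1 new key.
print(visible_keys(causal_mask(1, 5, "upper_left")[0]))   # [0]
print(visible_keys(causal_mask(1, 5, "lower_right")[0]))  # [0, 1, 2, 3, 4]
```

When q_len equals k_len the two alignments coincide, so the difference only shows up once a KV cache makes the mask non-square.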

Comment on lines +489 to +492
q = torch.cat([text_q, vae_q], dim=0).unsqueeze(0)
k = torch.cat([ctx_k, vae_k], dim=0).unsqueeze(0)
v = torch.cat([ctx_v, vae_v], dim=0).unsqueeze(0)
attn_out = self.attn_noncausal(q, k, v)

P1: Keep batched CFG branches isolated in attention

In non-SP image generation with CFG enabled, _generate_image_single builds cfg_batched by concatenating the main/text/img branches and passes batched_query_lens, but this path flattens all text and VAE tokens into a single SDPA sequence and never uses those lengths. As a result, branch 0/1/2 tokens attend to each other instead of remaining independent CFG branches, corrupting the guidance prediction whenever batched CFG is used.
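
To make the isolation requirement concrete, here is a minimal sketch of the block-diagonal mask that per-branch lengths imply. It uses plain Python lists, the helper name block_diagonal_mask is hypothetical, and the branch lengths are invented for illustration; it is not the masking code used in vllm_omni.

```python
from itertools import accumulate


def block_diagonal_mask(query_lens, kv_lens):
    """Mask over packed tokens where each branch attends only to its own keys.

    query_lens[b] and kv_lens[b] are the query and key counts of branch b.
    True means attention is allowed.
    """
    if len(query_lens) != len(kv_lens):
        raise ValueError("query_lens and kv_lens must have one entry per branch")
    q_bounds = [0, *accumulate(query_lens)]
    k_bounds = [0, *accumulate(kv_lens)]
    mask = [[False] * k_bounds[-1] for _ in range(q_bounds[-1])]
    for b in range(len(query_lens)):
        for i in range(q_bounds[b], q_bounds[b + 1]):
            for j in range(k_bounds[b], k_bounds[b + 1]):
                mask[i][j] = True
    return mask


# Two hypothetical branches: 2 queries over 3 keys, then 1 query over 2 keys.
for row in block_diagonal_mask([2, 1], [3, 2]):
    print("".join("x" if allowed else "." for allowed in row))
# xxx..
# xxx..
# ...xx
```

Flattening all branches into one sequence without such a mask is equivalent to filling every cell with "x", which is the cross-branch attention the comment warns about.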

@lishunyang12 lishunyang12 added the ready label to trigger buildkite CI label May 28, 2026
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

lgtm

@hsliuustc0106 hsliuustc0106 merged commit 68a1d23 into vllm-project:main May 29, 2026
7 of 8 checks passed
tzhouam pushed a commit that referenced this pull request May 29, 2026
…GQA in SDPA Backend (#3728)

Signed-off-by: Alex Brooks <albrooks@redhat.com>
david6666666 pushed a commit to lishunyang12/vllm-omni that referenced this pull request May 31, 2026
Two divergences vs upstream Lance:

1. Initial noise mismatch — Bagel.prepare_input / Lance's
   prepare_video_latent call torch.randn(...) with no device or
   generator, falling back to CPU+fp32 + the global torch.manual_seed
   stream.  Upstream samples directly on CUDA+bf16 via
   torch.Generator(device=cuda).manual_seed(seed) (lance.py:1536).
   That gives a totally different noise tensor for the same seed.  Fix:
   regen packed_init_noises on-device with a fresh generator after
   prepare_vae_latent / prepare_video_latent in t2i, t2v, image_edit
   (i2v already had this).  Result: CK1 byte-perfect (max_diff=0, cos=1.0).

2. Batched CFG attention leakage — Bagel.forward packs cond + cfg_text
   branches into a single LLM forward with concatenated KV caches.
   PR vllm-project#3728 removed the hardcoded flash_attn_varlen call and replaced
   it with DiffusionAttention, but the cu_seqlens-based block-diagonal
   mask was lost in translation.  Without it, cond tokens attend to
   cfg_text tokens (and vice versa) on every layer, producing a model
   that no longer matches upstream's sequential per-branch forwards.
   Until the mask plumbing is restored, run each CFG branch
   sequentially through the single-forward path — matches upstream
   exactly.  Result: layer00 input identical, CK7 cosine 0.186 -> 0.965.

Also adds:

- vllm_omni/diffusion/models/bagel/_dump_hooks.py — env-gated
  (LANCE_DUMP_DIR) layer-by-layer checkpoint dumper so future
  alignment work can run lance-upstream/lance_compare/compare.py
  side-by-side.  Zero cost when the env is unset.
- DiffusionEngine._dummy_run skips when LANCE_DUMP_DIR is set
  (the 512x512 warmup contaminates the dump harness's per-call state
  before the real request arrives).

Verified all 7 Lance tasks visually:
- t2i      corgi astronaut now nearly identical to upstream
- t2v      red panda surfing
- i2v      snow leopard glacier
- image_edit  hat removed cleanly from portrait
- video_edit  red car on snowy road
- x2t_image / x2t_video already aligned via the cache-offset SDPA fix.

Signed-off-by: lishunyang12 <lishunyang12@163.com>
zengchuang-hw pushed a commit to zengchuang-hw/vllm-omni that referenced this pull request Jun 1, 2026
…GQA in SDPA Backend (vllm-project#3728)

Signed-off-by: Alex Brooks <albrooks@redhat.com>
alex-jw-brooks added a commit to alex-jw-brooks/vllm-omni that referenced this pull request Jun 2, 2026
The refactoring in vllm-project#3728 replaced direct flash_attn_varlen_func calls
with DiffusionAttention routing but used unsqueeze(0) for batch=1,
causing all CFG branches to cross-attend as one sequence.

Fix: build proper 4D (num_branches, seq, heads, dim) tensors with
left-padded K and attention mask. The existing _forward_varlen_masked
backend handles variable-length sequences correctly.

Also re-enables the text2img shared memory test that was disabled
by issue vllm-project#3977 due to this bug.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Alex Brooks <albrooks@redhat.com>
86MaxCao pushed a commit to 86MaxCao/vllm-omni that referenced this pull request Jun 4, 2026
…GQA in SDPA Backend (vllm-project#3728)

Signed-off-by: Alex Brooks <albrooks@redhat.com>

Labels

ready label to trigger buildkite CI

Development

Successfully merging this pull request may close these issues.

[Bug]: Bagel Explicitly Requires Flash Attention

5 participants