
[Perf/Fix] Reimplement Batched CFG Forward for Bagel#4098

Open
alex-jw-brooks wants to merge 12 commits into vllm-project:main from alex-jw-brooks:bagel_fixes

Conversation

@alex-jw-brooks

Collaborator

Purpose

Related: #3977

#3728 removed the hard-coded flash attention code from Bagel, but broke the batched forward pass when there are multiple CFG branches. The main cause is that the text_cfg branch doesn't have KV values, so the KVs across the branches are uneven and were handled incorrectly.
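To make the "uneven KVs" problem concrete, here is a minimal NumPy sketch of one generic way to batch branches whose key/value lengths differ: pad to the longest branch and mask the padding. This is not taken from the PR and may not match how it handles the text_cfg branch; the branch names, lengths, and helper functions below are all hypothetical.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v, mask=None):
    # Dense scaled dot-product attention; `mask` is True where a key may be attended to.
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)
    return softmax(scores) @ v

rng = np.random.default_rng(0)
d = 4

# Hypothetical setup: the conditional branch has 3 cached KV entries plus 2 new
# tokens, while the text_cfg branch has no cached entries, only its 2 new tokens.
kv_lens = {"cond": 3 + 2, "text_cfg": 0 + 2}
q = {name: rng.standard_normal((2, d)) for name in kv_lens}
k = {name: rng.standard_normal((n, d)) for name, n in kv_lens.items()}
v = {name: rng.standard_normal((n, d)) for name, n in kv_lens.items()}

# Pad every branch to the longest KV length and record which keys are real.
names = list(kv_lens)
max_len = max(kv_lens.values())
k_pad = np.zeros((len(names), max_len, d))
v_pad = np.zeros((len(names), max_len, d))
key_mask = np.zeros((len(names), 1, max_len), dtype=bool)
for i, name in enumerate(names):
    n = kv_lens[name]
    k_pad[i, :n] = k[name]
    v_pad[i, :n] = v[name]
    key_mask[i, :, :n] = True

q_batch = np.stack([q[name] for name in names])  # (branches, 2, d)
out_batch = attention(q_batch, k_pad, v_pad, key_mask)

# Reference: run each branch on its own, unpadded.
out_ref = [attention(q[name], k[name], v[name]) for name in names]
```

With the padding masked out, each batch row should match the corresponding per-branch result; dropping the mask would let the shorter branch attend to zero-valued padding keys.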

The PR for Lance fixed correctness by calling the CFG branches sequentially, so the outputs on main should be correct for gen mode on Bagel. However, the tests are still disabled, and the test pixel values have not been updated to reflect some changes made in the Lance PR.

We should probably merge #4081 before this PR, since it updates the ground truth pixels and turns the e2e Bagel tti/i2i tests back on; that lets us validate that this PR won't change the outputs. For testing, though, I've also copied the updated pixel values over to this PR.

Test Plan

Will add some more details with testing & examples tomorrow.

CC @Gaohan123 @lishunyang12 @zhangj1an @natureofnature @princepride


@Gaohan123 Gaohan123 added this to the v0.22.0 milestone Jun 3, 2026
@Gaohan123

Collaborator

@zhangj1an @princepride PTAL

@zhangj1an

zhangj1an commented Jun 4, 2026

Contributor

LGTM, good to merge.

  • Previously, all CFG branches' tokens in Bagel were concatenated into a single sequence. In a forward step, a single dense attention was applied over that sequence, which caused the branches to cross-attend and contaminate each other. This PR instead stacks the branches on the batch dimension (q_4d/k_4d/v_4d), which is equivalent to running a separate attention computation per branch, so each branch only attends to its own cache + text + vae.

  • This PR does not reuse CFGParallelMixin, because the batched-CFG-with-KV-cache logic is specific to AR models like Bagel/Lance (per-branch KV cache, fused into one batched forward). CFGParallelMixin is meant for DiT models (it assumes stateless, independent per-branch forwards), which do not deal with a KV cache.

Example: main branch (token concatenation) vs. this branch (batch-dim stacking)

Assume 2 CFG branches, each with 2 query tokens (real Bagel tokens are [text…, vae…] per branch, but 2 each is enough to show the idea):

  • Branch A (conditional): tokens a1, a2
  • Branch B (unconditional / text_cfg): tokens b1, b2
  1. Previously, the tokens were concatenated into one sequence (batch 1, seq 4):
q = [a1, a2, b1, b2]   ->  shape (1, 4, d)
k = [a1, a2, b1, b2]   ->  shape (1, 4, d)

q @ kᵀ is a 4×4 score matrix.

          key: a1   a2   b1   b2
   q a1  [     ok   ok   XX   XX  ]
     a2  [     ok   ok   XX   XX  ]
     b1  [     XX   XX   ok   ok  ]
     b2  [     XX   XX   ok   ok  ]

The XX cells should not be there, because CFG branches should not attend to each other.

  2. In this PR, the tokens are stacked on the batch dimension. Each branch is its own batch row (batch 2, seq 2):
q_4d = [ [a1, a2],     ->  shape (2, 2, d)
         [b1, b2] ]
k_4d = [ [a1, a2],
         [b1, b2] ]

Batched attention runs q @ kᵀ independently per row, which results in two separate 2×2 matrices. This keeps each branch independent and is more lightweight than my previously proposed method (still using one large matrix, just with a block-diagonal mask added).

   batch row 0 (branch A)        batch row 1 (branch B)
        key: a1   a2                  key: b1   b2
  q a1 [    ok   ok ]           q b1 [    ok   ok ]
    a2 [    ok   ok ]             b2 [    ok   ok ]
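The two layouts above can be checked numerically. The following is a minimal NumPy sketch, not Bagel's actual attention code: the `attention` and `softmax` helpers, the random inputs, and the single-head `(batch, heads, seq, d)` layout assumed for `q_4d`/`k_4d`/`v_4d` are all illustrative assumptions.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v, mask=None):
    # Dense scaled dot-product attention; `mask` is True where attention is allowed.
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)
    return softmax(scores) @ v

rng = np.random.default_rng(0)
d = 8
# Two tokens per branch, as in the example: (a1, a2) and (b1, b2).
q_a, k_a, v_a = rng.standard_normal((3, 2, d))
q_b, k_b, v_b = rng.standard_normal((3, 2, d))

# Reference: each branch attends only to itself.
ref_a = attention(q_a, k_a, v_a)
ref_b = attention(q_b, k_b, v_b)

# 1. Token concatenation: one sequence of 4 tokens, one 4x4 score matrix.
q_cat = np.concatenate([q_a, q_b])
k_cat = np.concatenate([k_a, k_b])
v_cat = np.concatenate([v_a, v_b])
out_cat = attention(q_cat, k_cat, v_cat)
leak = np.abs(out_cat[:2] - ref_a).max()  # non-zero: branch A saw branch B's keys

# 2. Batch-dim stacking: each branch is its own batch row.
#    Assumed layout: (batch, heads, seq, d) with a single head.
q_4d = np.stack([q_a, q_b])[:, None]
k_4d = np.stack([k_a, k_b])[:, None]
v_4d = np.stack([v_a, v_b])[:, None]
out_4d = attention(q_4d, k_4d, v_4d)

# 3. The masked alternative: keep the 4x4 matrix but hide the XX cells.
branch_id = np.array([0, 0, 1, 1])
block_mask = branch_id[:, None] == branch_id[None, :]
out_masked = attention(q_cat, k_cat, v_cat, block_mask)
```

In this sketch the stacked and masked variants both reproduce the per-branch references, while the unmasked concatenation does not. The masked variant still computes the full 4×4 score matrix, which is why stacking is the lighter option.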

@Gaohan123 Gaohan123 left a comment

Collaborator


Thanks. Could you please post the generation results before and after the PR to intuitively show that the images are now correct? It would also be better to provide a table comparing performance before and after the PR.

{"position": (400, 700), "rgb": (130, 96, 77)},
{"position": (700, 700), "rgb": (247, 203, 140)},
{"position": (256, 256), "rgb": (167, 156, 150)},
{"position": (100, 100), "rgb": (64, 45, 35)},

Collaborator


Does this mean the previous ground truth was wrong? cc @princepride

Contributor


Thanks for the catch. Following @princepride's reply in #4081, we agreed not to change the reference image pixels, since they were already correct. I think @alex-jw-brooks will undo this part in test_bagel_mooncake_connector.py and test_bagel_shared_memory_connector.py.
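For readers unfamiliar with this kind of check, the reference entries quoted in this thread pair a pixel position with an expected RGB value. Below is a rough sketch of how such entries could be compared against a generated image. The helper name, the tolerance, the 1024×1024 image size, and the reading of `position` as `(x, y)` are assumptions for illustration, not the behaviour of the actual Bagel tests.

```python
import numpy as np

# Entries in the same shape as the ones quoted in the review thread.
REFERENCE_PIXELS = [
    {"position": (400, 700), "rgb": (130, 96, 77)},
    {"position": (100, 100), "rgb": (64, 45, 35)},
]

def mismatched_pixels(image, reference, tolerance=5):
    """Return the reference entries whose RGB differs from `image` by more than `tolerance`.

    Assumes `image` is an (H, W, 3) uint8 array and `position` is (x, y).
    """
    bad = []
    for entry in reference:
        x, y = entry["position"]
        actual = image[y, x].astype(int)  # avoid uint8 wrap-around on subtraction
        expected = np.array(entry["rgb"])
        if np.abs(actual - expected).max() > tolerance:
            bad.append(entry)
    return bad

# Synthetic stand-in for a generated image that matches the references exactly.
image = np.zeros((1024, 1024, 3), dtype=np.uint8)
for entry in REFERENCE_PIXELS:
    x, y = entry["position"]
    image[y, x] = entry["rgb"]
matching = mismatched_pixels(image, REFERENCE_PIXELS)

# Shift one channel by 30 to simulate a change in the generation output.
image[700, 400] = (160, 96, 77)
drifted = mismatched_pixels(image, REFERENCE_PIXELS)
```

A check of this shape only flags that the output moved; deciding whether the old or the new values are the correct reference still needs the kind of discussion seen in this thread.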

@Gaohan123 Gaohan123 added the ready label to trigger buildkite CI label Jun 4, 2026
@princepride

Collaborator

I don't think the previous pixels are wrong😑

Signed-off-by: Alex Brooks <albrooks@redhat.com>

minor

Signed-off-by: Alex Brooks <albrooks@redhat.com>

fix ref

Signed-off-by: Alex Brooks <albrooks@redhat.com>
@alex-jw-brooks

alex-jw-brooks commented Jun 4, 2026

Collaborator Author

Hey @princepride @Gaohan123 @zhangj1an Yes, when I opened this PR, it had just been updated to reflect current main at the time, which is why the pixel values were changed 😅

Rebased now after the discussions in #4081. Since the latent generation changes were reverted, the old values are correct.
