[Perf/Fix] Reimplement Batched CFG Forward for Bagel #4098
alex-jw-brooks wants to merge 12 commits
@zhangj1an @princepride PTAL
LGTM, this is good to merge.
Example of main branch vs. this branch
main branch (token-concatenation) vs. this branch (batch-dim stacking)
Assume 2 CFG branches, each with 2 query tokens (real Bagel tokens are
The
Batched attention runs
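To make the concatenation vs. stacking distinction concrete, here is a minimal NumPy sketch using the same toy setup (2 CFG branches, 2 query tokens each). It is not the Bagel code: the `attention` helper, the tensor names, and the head dimension of 4 are all assumptions for illustration, and whether `main` applied any per-branch masking is not shown in this thread. The sketch only shows that stacking on a batch dimension keeps each branch isolated, while concatenating tokens without a per-branch mask lets queries attend to keys from the other branch.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(q, k, v):
    # Plain scaled dot-product attention.
    # q: (..., Tq, d), k and v: (..., Tk, d). Leading dims are batch dims.
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    return softmax(scores) @ v


rng = np.random.default_rng(0)
num_branches, num_tokens, dim = 2, 2, 4  # toy sizes, assumed for illustration
q = rng.standard_normal((num_branches, num_tokens, dim))
k = rng.standard_normal((num_branches, num_tokens, dim))
v = rng.standard_normal((num_branches, num_tokens, dim))

# Reference: run each CFG branch on its own, one after another.
sequential = np.stack([attention(q[i], k[i], v[i]) for i in range(num_branches)])

# Batch-dim stacking: branches sit on the leading dimension, so each
# branch's queries only ever see that branch's keys and values.
stacked = attention(q, k, v)

# Token concatenation with no per-branch mask: all 4 tokens share one
# sequence, so every query also attends to the other branch's keys.
concat = attention(
    q.reshape(-1, dim), k.reshape(-1, dim), v.reshape(-1, dim)
).reshape(num_branches, num_tokens, dim)

stacked_matches = np.allclose(stacked, sequential)
concat_matches = np.allclose(concat, sequential)
print("stacked == sequential:", stacked_matches)
print("unmasked concat == sequential:", concat_matches)
```

A concatenated layout can still be correct if the kernel is given per-sequence boundaries or a block-diagonal mask; the point here is only that the batch dimension gives that isolation for free.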
Gaohan123 left a comment
Thanks. Could you please post the generation results before and after the PR to intuitively show that the images are now correct? It would also be good to provide a table comparing the performance before and after the PR.
{"position": (400, 700), "rgb": (130, 96, 77)},
{"position": (700, 700), "rgb": (247, 203, 140)},
{"position": (256, 256), "rgb": (167, 156, 150)},
{"position": (100, 100), "rgb": (64, 45, 35)},
Does this mean the previous ground truth is wrong? cc @princepride
Thanks for the catch. Following @princepride's reply in #4081, we agreed not to change the reference image pixels, since they were already correct. So I think @alex-jw-brooks will undo this part in test_bagel_mooncake_connector.py and test_bagel_shared_memory_connector.py.
I don't think the previous pixels are wrong😑
Signed-off-by: Alex Brooks <albrooks@redhat.com> minor Signed-off-by: Alex Brooks <albrooks@redhat.com> fix ref Signed-off-by: Alex Brooks <albrooks@redhat.com>
530fa35 to 4666ee3
Hey @princepride @Gaohan123 @zhangj1an Yes, when I opened this PR, it had just been updated to reflect current
Rebased now after the discussions from #4081; since the latent generation stuff was reverted, the old values are correct.
Purpose
Related: #3977
#3728 removed the hard-coded flash attention code from Bagel, but broke the batched forward when there are multiple CFG branches, largely because the text_cfg branch doesn't have kv values, so the kvs across the branches are uneven and were handled incorrectly.
The PR for Lance fixed the correctness by calling the CFG branches sequentially, so the outputs on main should be correct for gen mode on Bagel, but the tests are still disabled, and the test pixel values have not been updated to reflect some changes made in the Lance PR.
We should probably merge #4081 first before this PR, since it updates the ground truth pixels and turns the e2e bagel tti/i2i tests back on, and we can validate that this PR won't change the outputs. For testing though, I've copied the updated pixel values over to this PR also.
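For readers unfamiliar with the uneven-kv problem, the sketch below shows one generic way to batch branches whose key/value lengths differ: pad to the longest branch and mask out the padded keys. This is an assumption-laden illustration, not a description of what this PR does. The `masked_attention` helper, the cached lengths `[3, 0]`, and the reading of "doesn't have kv values" as "has no cached kv, only its own current tokens" are all hypothetical.

```python
import numpy as np


def masked_attention(q, k, v, key_mask):
    # q: (B, Tq, d); k, v: (B, Tk, d); key_mask: (B, Tk), True for real keys.
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    # Padded keys get -inf so they receive exactly zero attention weight.
    scores = np.where(key_mask[:, None, :], scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


rng = np.random.default_rng(0)
dim, num_query = 4, 2
# Hypothetical cached-kv lengths per CFG branch; the second branch
# (standing in for text_cfg) has no cached kv at all.
cached_lens = [3, 0]

branches = []
for n in cached_lens:
    q = rng.standard_normal((num_query, dim))
    # Keys/values = cached entries plus the current query tokens.
    k = rng.standard_normal((n + num_query, dim))
    v = rng.standard_normal((n + num_query, dim))
    branches.append((q, k, v))

# Pad every branch to the longest kv length and record which keys are real.
max_len = max(k.shape[0] for _, k, _ in branches)
q_batch = np.stack([q for q, _, _ in branches])
k_batch = np.zeros((len(branches), max_len, dim))
v_batch = np.zeros_like(k_batch)
mask = np.zeros((len(branches), max_len), dtype=bool)
for i, (_, k, v) in enumerate(branches):
    k_batch[i, : k.shape[0]] = k
    v_batch[i, : v.shape[0]] = v
    mask[i, : k.shape[0]] = True

batched = masked_attention(q_batch, k_batch, v_batch, mask)

# Reference: each branch on its own, with no padding involved.
sequential = np.stack(
    [
        masked_attention(q[None], k[None], v[None], np.ones((1, k.shape[0]), dtype=bool))[0]
        for q, k, v in branches
    ]
)
padding_is_safe = np.allclose(batched, sequential)
print("padded batch == sequential:", padding_is_safe)
```

The same comparison against a sequential per-branch run is a cheap way to check that any batched variant leaves the outputs unchanged.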
Test Plan
Will add some more details with testing & examples tomorrow.
CC @Gaohan123 @lishunyang12 @zhangj1an @natureofnature @princepride