Skip to content

[Core] Support Async & Sync AutoRegressive Scheduling#3306

Merged
Gaohan123 merged 7 commits into
vllm-project:mainfrom
alex-jw-brooks:async_scheduler
May 5, 2026
Merged

[Core] Support Async & Sync AutoRegressive Scheduling#3306
Gaohan123 merged 7 commits into
vllm-project:mainfrom
alex-jw-brooks:async_scheduler

Conversation

@alex-jw-brooks
Copy link
Copy Markdown
Contributor

@alex-jw-brooks alex-jw-brooks commented May 2, 2026

Purpose

Fixes the ongoing Bagel failures for img2img: #3268

The superclass for the autoregressive scheduler was recently switched from vLLM's Scheduler to AsyncScheduler as part of the optimizations in #3203. The scheduler change is actually the root cause of the Bagel pixel mismatches, because it changes the output of the AR component.

Since the async_scheduling is passable in the config, we should properly support both synchronous and asynchronous scheduling, which is most easily accomplished by just making the current AR scheduler abstract, and creating synchronous / asynchronous subclasses that just inherit from the corresponding vLLM class. Note that async AR scheduling is enabled by default since this aligns with vLLM's behavior and the recent optimizations.

Test Plan

The changes are relatively small since it's mostly adding two subclasses that only differ in one they inherit from and some of the related plumbing. The bagel test that is failing on main should now pass due to async_scheduling": False in the overlay, which will delegate to OmniARScheduler. You should also be able to reproduce the pixel mismatch by removing the explict sync scheduling from the overlay.

Test Result

Bagel test failing in the CI is fixed, and scheduling can now be set to sync / async as expected for AR stages, rather than having the scheduler inherit from only one of the scheduler classes.

CC @hsliuustc0106 @princepride @amy-why-3459 @lishunyang12 @yenuo26

@chatgpt-codex-connector
Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository wide code reviews.

@alex-jw-brooks alex-jw-brooks changed the title [Core] Support Async & Sync AutoRegressive Scheduling [Core/CI] Support Async & Sync AutoRegressive Scheduling May 2, 2026
Copy link
Copy Markdown
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BLOCKING:

  • Breaking Changes — Behavior change without migration guide. Default changes output for AR stages (this is the fix for Bagel pixel mismatches, but users may be affected). Add a migration guide or at least a note in PR description about what users should expect.

Non-blocking:

  • Documentation — Deprecation warnings for are good. Consider documenting the new config option in the deploy config docs.

hsliuustc0106

This comment was marked as duplicate.

@hsliuustc0106 hsliuustc0106 added this to the v0.20.0 milestone May 2, 2026
@hsliuustc0106 hsliuustc0106 added ready label to trigger buildkite CI merge-test label to trigger buildkite merge test CI and removed ready label to trigger buildkite CI labels May 2, 2026
@hsliuustc0106
Copy link
Copy Markdown
Collaborator

@princepride I thought async_scheduling is a default selection for AR part, do you know why bagel ar part cannot be serving with aysnc scheduling with accuracy?

@hsliuustc0106
Copy link
Copy Markdown
Collaborator

@alex-jw-brooks thanks, let's try the CI

@amy-why-3459
Copy link
Copy Markdown
Contributor

@princepride I thought async_scheduling is a default selection for AR part, do you know why bagel ar part cannot be serving with aysnc scheduling with accuracy?

Thank you so much for your quick fix. I also think we need to figure out if bagel is unable to use async_scheduling? Why is bagel unable to use async_scheduling? I tested the results of bagel and it seems to be normal; it's just that the threshold needs to be relaxed.

@princepride
Copy link
Copy Markdown
Collaborator

😂, I also want know why async or sync will result model get different result.

@alex-jw-brooks
Copy link
Copy Markdown
Contributor Author

alex-jw-brooks commented May 3, 2026

Hey @hsliuustc0106 @amy-why-3459 @princepride I am still working on understanding the root cause and the right fix here, still learning a lot about scheduling. 😅

I am pretty sure what is happening is that we should be doing the kv transfer after prefill of the first stage, but with async scheduling, it is scheduling + executing a decode that is polluting the transferred states, which is why the outputs are wrong in this case, but maybe not as visible in other models.

Will continue looking into this tomorrow - I have some ideas for how to fix it I think, but need to make sure it doesn't have weird implications for normal prefill/decode handling

Comment thread vllm_omni/core/sched/omni_ar_scheduler.py Outdated
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
# async scheduling, since both exclude async placeholders. We use
# seq_len since we control it, just in case upstream async scheduler
# semantics change in the future.
num_computed = data.get("seq_len")
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly for safety since this is what is set when we mark the request for transfer

# Output placeholders are zero when async scheduling isn't used
return request.num_computed_tokens - request.num_output_placeholders

def _update_request_with_output(self, request: Request, new_token_ids: list[int]) -> tuple[list[int], bool]:
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alex-jw-brooks
Copy link
Copy Markdown
Contributor Author

alex-jw-brooks commented May 3, 2026

@hsliuustc0106 @princepride @lishunyang12 @amy-why-3459 this is ready for another look. I think it is straightforward now after #3318 was merged.

Also, the reason for the outputs going off so far in the noise image were mostly due to the prefill/decode rope handling in the Bagel model, which @princepride's PR fixed.

            if num_computed_tokens > prefill_rope:
                meta["ropes"] = [num_computed_tokens]

^ In the example, num_computed_tokens is ~7600, but prefill rope is quite small (~10), so that shift from rope position is what is messing things up the most. The async placeholders are correct now as well, but mishandling the async placeholders would shift from 10 -> 11, which is not visually discernable.

In case anyone else was worried about one placeholder affecting the i2i result to such an extreme degree 😅

"ropes": [rope],
"image_shape": [img_H, img_W],
"prefill_position_count": int(end - start),
"prefill_position_count": req_len,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit confused. So your mean is that with this modification, bagel can work correctly under async scheduling?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And if so, we can revert the threshold from 10 to 5 and test it again. Thanks

Copy link
Copy Markdown
Contributor Author

@alex-jw-brooks alex-jw-brooks May 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, we should be able to use 5 now, just pushed to reduce it. This line is just a refactor, since req_len = end - start is already set earlier. The main related fix is in this related PR #3318 from @princepride, which fixed async scheduling counts. But we still need this PR to allow disabling async scheduling correctly for AR models.

Basically the problem was conflating token counts with positions in RoPE. In this case, the new_positions for RoPE should be something like this:

 [0, 0, 0, ..., 0, 0, 0, 0, 1, 1, 1, 1, ..., 1, 1, 2, 3, 4, 5, 6, 7, 8, 9] # len is ~7600

because visual components share the same position. So then for this example, num_post_text is 8, and the "ropes" is 10 (M=0, so it's 2 + num_post_text).

In the previous code, if an async placeholder was let through, we'd get something like num_computed_tokens=7601, and do this.

            if num_computed_tokens > prefill_rope:
                meta["ropes"] = [num_computed_tokens]

but num_computed_tokens is tokens, not positions. So it would overwrite that 10 -> 7601 because it was not considering that those positions should be shared, which would completely destroy outputs.

With the fix the on model side of things, letting an async placeholder instead looks like a decode, so it would overwrite 10 -> 11. However, as these are very close, it doesn't have a dramatic impact. This is now fixed too though.

There are some things that still look a little weird to me in the AR part of Bagel, but at least positions now match synchronous so results are fine, so I think this PR should be fine now

Signed-off-by: Alex Brooks <albrooks@redhat.com>
@alex-jw-brooks alex-jw-brooks changed the title [Core/CI] Support Async & Sync AutoRegressive Scheduling [Core] Support Async & Sync AutoRegressive Scheduling May 4, 2026
@alex-jw-brooks
Copy link
Copy Markdown
Contributor Author

Looks like the current failure is unrelated, but things are green again with tolerance reduced to 5 for the bagel tests @Gaohan123 🙂

Copy link
Copy Markdown
Collaborator

@Gaohan123 Gaohan123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your explanation. It is clear for me now. LGTM.

@Gaohan123 Gaohan123 enabled auto-merge (squash) May 4, 2026 16:10
@Gaohan123 Gaohan123 disabled auto-merge May 4, 2026 16:27
@Gaohan123
Copy link
Copy Markdown
Collaborator

https://buildkite.com/vllm/vllm-omni/builds/8842/canvas?sid=019df3c9-8999-4c80-bfb8-7bb24990189d&tab=output
It seems 5 is still a bit strict... Maybe possible 1 token shift also takes effect. 10 is ok I think

Signed-off-by: Alex Brooks <albrooks@redhat.com>
@alex-jw-brooks
Copy link
Copy Markdown
Contributor Author

alex-jw-brooks commented May 4, 2026

Ah it looks like a different test, it's TTI and this has been fixing I2I 😞 sounds good though, I reverted it back to 10.

It could also be due to changes in flash attention backend since Bagel explicitly depends on it. I can't reproduce it locally, so maybe there are slight changes in which FA package is used 🤔 I was going to remove the explicit FA stuff in Bagel when I have time, so we can reduce it there if we find anything else that causes a divergence 🤞

@Gaohan123 Gaohan123 merged commit bb239fa into vllm-project:main May 5, 2026
7 of 8 checks passed
@yenuo26 yenuo26 linked an issue May 6, 2026 that may be closed by this pull request
1 task
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026
)

Signed-off-by: Alex Brooks <albrooks@redhat.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

merge-test label to trigger buildkite merge test CI ready label to trigger buildkite CI

Projects

None yet

6 participants