Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions tests/test_config_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import pytest

from vllm_omni.config.stage_config import (
_EXECUTION_TYPE_TO_SCHEDULER,
_PIPELINE_REGISTRY,
ModelPipeline,
PipelineConfig,
Expand All @@ -20,6 +19,7 @@
StageExecutionType,
StagePipelineConfig,
StageType,
_resolve_scheduler,
build_stage_runtime_overrides,
register_pipeline,
strip_parent_engine_args,
Expand Down Expand Up @@ -768,25 +768,29 @@ def test_validate_no_stages(self):
p = PipelineConfig(model_type="t", model_arch="A")
assert any("no stages" in e.lower() for e in p.validate())

def test_get_scheduler_cls(self):
    """Each stage's execution type resolves to the matching scheduler class path."""
    ar_stage = StagePipelineConfig(stage_id=0, model_stage="a", execution_type=StageExecutionType.LLM_AR)
    generation_stage = StagePipelineConfig(
        stage_id=1,
        model_stage="b",
        execution_type=StageExecutionType.LLM_GENERATION,
        input_sources=(0,),
    )
    pipeline = PipelineConfig(model_type="t", model_arch="A", stages=(ar_stage, generation_stage))

    # Stage id -> substring expected in the returned scheduler path.
    expected_by_stage = {0: "OmniARScheduler", 1: "OmniGenerationScheduler"}
    for stage_id, scheduler_name in expected_by_stage.items():
        assert scheduler_name in pipeline.get_scheduler_cls(stage_id)


class TestExecutionTypeToScheduler:
def test_all_types_mapped(self):
class TestResolveScheduler:
def test_all_execution_types_handled(self):
for et in StageExecutionType:
assert et in _EXECUTION_TYPE_TO_SCHEDULER
_resolve_scheduler(et)

def test_ar_sync_when_false(self):
cls = _resolve_scheduler(StageExecutionType.LLM_AR, async_scheduling=False)
assert cls is not None
assert "Async" not in cls.__name__

def test_ar_async_when_true(self):
cls = _resolve_scheduler(StageExecutionType.LLM_AR, async_scheduling=True)
assert cls is not None
assert "Async" in cls.__name__

def test_generation(self):
cls = _resolve_scheduler(StageExecutionType.LLM_GENERATION)
assert cls is not None
assert "Generation" in cls.__name__

def test_diffusion_returns_none(self):
assert _resolve_scheduler(StageExecutionType.DIFFUSION) is None


class TestPipelineRegistry:
Expand Down
69 changes: 45 additions & 24 deletions vllm_omni/config/stage_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from typing import Any

from vllm.logger import init_logger
from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler

from vllm_omni.config.yaml_util import create_config, load_yaml_config, to_dict
from vllm_omni.core.sched.omni_ar_scheduler import OmniARScheduler
from vllm_omni.core.sched.omni_ar_scheduler import OmniARAsyncScheduler, OmniARScheduler
from vllm_omni.core.sched.omni_generation_scheduler import OmniGenerationScheduler

_MODELS_DIR = Path(__file__).resolve().parent.parent / "model_executor" / "models"
Expand Down Expand Up @@ -136,18 +137,27 @@ class StageExecutionType(str, Enum):
DIFFUSION = "diffusion"


# Mapping class refs (not dotted-path strings) so module/class renames fail
# at import time instead of lazily at scheduler resolution. YAML overrides
# and downstream serialization still use the dotted-path string form; the
# conversion happens at the map lookup site via _scheduler_path().
_EXECUTION_TYPE_TO_SCHEDULER: dict[StageExecutionType, type | None] = {
StageExecutionType.LLM_AR: OmniARScheduler,
StageExecutionType.LLM_GENERATION: OmniGenerationScheduler,
StageExecutionType.DIFFUSION: None,
}
def _resolve_scheduler(
    execution_type: StageExecutionType,
    async_scheduling: bool = True,
) -> type[VLLMScheduler] | None:
    """Pick the scheduler class that matches ``execution_type``.

    Autoregressive (``LLM_AR``) stages have separate sync and async
    schedulers, and ``async_scheduling`` decides which one is used. Every
    other execution type ignores ``async_scheduling``.

    Returns:
        ``OmniARAsyncScheduler`` or ``OmniARScheduler`` for ``LLM_AR``,
        ``OmniGenerationScheduler`` for ``LLM_GENERATION``, and ``None``
        otherwise (currently the diffusion case).
    """
    if execution_type == StageExecutionType.LLM_GENERATION:
        return OmniGenerationScheduler
    if execution_type != StageExecutionType.LLM_AR:
        # Diffusion currently has no scheduler class here.
        return None
    return OmniARAsyncScheduler if async_scheduling else OmniARScheduler


def _scheduler_path(cls: type | None) -> str | None:
def _scheduler_path(cls: type[VLLMScheduler] | None) -> str | None:
"""Return the dotted import path for a scheduler class (``None`` passes through)."""
if cls is None:
return None
Expand Down Expand Up @@ -213,18 +223,6 @@ def get_stage(self, stage_id: int) -> StagePipelineConfig | None:
return stage
return None

def get_scheduler_cls(self, stage_id: int) -> str | None:
    """Look up the inferred scheduler class path for one stage.

    Returns:
        The dotted scheduler path, or ``None`` for DIFFUSION stages
        (no vLLM scheduler).

    Raises:
        ValueError: ``stage_id`` does not exist in this pipeline.
        KeyError: the stage's ``execution_type`` is missing from the
            scheduler map.
    """
    stage = self.get_stage(stage_id)
    if stage is not None:
        scheduler_cls = _EXECUTION_TYPE_TO_SCHEDULER[stage.execution_type]
        return _scheduler_path(scheduler_cls)
    raise ValueError(f"Pipeline {self.model_type!r} has no stage with id {stage_id}")

def validate(self) -> list[str]:
"""Return list of topology errors (empty if valid)."""
errors: list[str] = []
Expand Down Expand Up @@ -796,6 +794,12 @@ def merge_pipeline_deploy(
stage_type, worker_type = _resolve_execution_mode(ps.execution_type)
input_proc, next_stage_proc = _select_processor_funcs(ps, deploy.async_chunk)
engine_args = _build_engine_args(ps, ds, pipeline, deploy, next_stage_proc)
sched_cls = _resolve_scheduler(
ps.execution_type,
engine_args.get("async_scheduling", True),
)
if ps.execution_type == StageExecutionType.LLM_AR:
engine_args["async_scheduling"] = sched_cls is OmniARAsyncScheduler
extras = _build_extras(ps, ds)
runtime: dict[str, Any] = {"process": True}
if ds is not None:
Expand All @@ -811,7 +815,7 @@ def merge_pipeline_deploy(
final_output=ps.final_output,
final_output_type=ps.final_output_type,
worker_type=worker_type,
scheduler_cls=_scheduler_path(_EXECUTION_TYPE_TO_SCHEDULER.get(ps.execution_type)),
scheduler_cls=_scheduler_path(sched_cls),
hf_config_name=ps.hf_config_name,
is_comprehension=ps.owns_tokenizer,
yaml_engine_args=engine_args,
Expand Down Expand Up @@ -1280,6 +1284,23 @@ def _parse_pipeline_yaml(cls, path: Path, model_type: str) -> ModelPipeline:
# YAMLs (worker_type, scheduler_cls, etc.) — read from both places.
worker_type = stage_data.get("worker_type", None) or yaml_engine_args.pop("worker_type", None)
scheduler_cls = stage_data.get("scheduler_cls", None) or yaml_engine_args.pop("scheduler_cls", None)
if scheduler_cls:
async_sched = yaml_engine_args.get("async_scheduling")
if async_sched is not None:
logger.warning(
"Stage %s: async_scheduling=%r and scheduler_cls=%r "
"should not be set together. scheduler_cls will take "
"precedence for which scheduler is used.",
stage_data.stage_id,
async_sched,
scheduler_cls,
)
else:
logger.warning(
"Stage %s: scheduler_cls=%r is deprecated. Use async_scheduling instead.",
stage_data.stage_id,
scheduler_cls,
)
hf_config_name = stage_data.get("hf_config_name", None) or yaml_engine_args.pop("hf_config_name", None)
model_stage = getattr(stage_data, "model_stage", None) or yaml_engine_args.pop("model_stage", None)

Expand Down
3 changes: 2 additions & 1 deletion vllm_omni/core/sched/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
Scheduling components for vLLM-Omni.
"""

from .omni_ar_scheduler import OmniARScheduler
from .omni_ar_scheduler import OmniARAsyncScheduler, OmniARScheduler
from .omni_generation_scheduler import OmniGenerationScheduler
from .output import OmniNewRequestData

__all__ = [
"OmniARAsyncScheduler",
"OmniARScheduler",
"OmniGenerationScheduler",
"OmniNewRequestData",
Expand Down
39 changes: 10 additions & 29 deletions vllm_omni/core/sched/omni_ar_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from vllm.distributed.kv_events import KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.logger import init_logger
from vllm.v1.core.sched.async_scheduler import AsyncScheduler as VLLMScheduler
from vllm.v1.core.sched.async_scheduler import AsyncScheduler as AsyncVLLMScheduler
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as SyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler
from vllm.v1.core.sched.utils import remove_all
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.metrics.perf import PerfStats
Expand Down Expand Up @@ -41,12 +41,9 @@ def to_dict(self) -> dict[str, Any]:


class OmniARScheduler(OmniSchedulerMixin, VLLMScheduler):
"""
OmniARScheduler: Scheduler for vLLM-Omni multimodal processing.

This scheduler extends vLLM's scheduler to support multimodal and
non-autoregressive processing with additional fields and methods
specific to vLLM-Omni.
"""Synchronous AutoRegressive scheduler for vLLM-Omni. This class is also
used as a base class for the OmniARAsyncScheduler and holds most of the
core scheduling logic.
"""

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -83,29 +80,9 @@ def __init__(self, *args, **kwargs):

def _get_confirmed_num_computed_tokens(self, request: Request) -> int:
    """Return num_computed_tokens with async placeholders removed (KV actually on GPU)."""
    computed = request.num_computed_tokens
    # When async scheduling isn't used the placeholder count is zero, so
    # the subtraction leaves num_computed_tokens unchanged.
    placeholders = request.num_output_placeholders
    return computed - placeholders

def _update_request_with_output(self, request: Request, new_token_ids: list[int]) -> tuple[list[int], bool]:
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""Append output tokens, then cache blocks up to the confirmed count
so KV transfer never sees blocks whose data has not been computed yet.
"""
if request.discard_latest_async_tokens:
request.discard_latest_async_tokens = False
return [], False

status_before_update = request.status

new_token_ids, stopped = SyncScheduler._update_request_with_output(self, request, new_token_ids)

request.num_output_placeholders -= len(new_token_ids)
assert request.num_output_placeholders >= 0

if status_before_update == RequestStatus.RUNNING:
confirmed = self._get_confirmed_num_computed_tokens(request)
self.kv_cache_manager.cache_blocks(request, confirmed)

return new_token_ids, stopped

def _get_kv_transfer_criteria(self) -> dict | None:
# Note: vllm_config is available in Scheduler after super().__init__
if not hasattr(self, "vllm_config"):
Expand Down Expand Up @@ -786,3 +763,7 @@ def get_finished_requests_needing_kv_transfer(self) -> dict[str, dict]:

self.requests_needing_kv_transfer.clear()
return requests


class OmniARAsyncScheduler(OmniARScheduler, AsyncVLLMScheduler):
    """Asynchronous AutoRegressive scheduler.

    Adds no members of its own: it combines ``OmniARScheduler`` with
    ``AsyncVLLMScheduler`` (vLLM's ``AsyncScheduler``) through multiple
    inheritance.
    """
3 changes: 2 additions & 1 deletion vllm_omni/model_executor/models/bagel/bagel.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,7 @@ def get_kv_transfer_metadata(
*,
num_computed_tokens: int | None = None,
) -> dict[str, Any] | None:
# NOTE: num_computed_tokens will not include async placeholders
meta = self._ropes_metadata.pop(req_id, None)
if meta is None:
return None
Expand Down Expand Up @@ -854,7 +855,7 @@ def _adjust_positions_for_img2img(
{
"ropes": [rope],
"image_shape": [img_H, img_W],
"prefill_position_count": int(end - start),
"prefill_position_count": req_len,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit confused. Do you mean that with this modification, Bagel can work correctly under async scheduling?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And if so, we can revert the threshold from 10 to 5 and test it again. Thanks

Copy link
Copy Markdown
Contributor Author

@alex-jw-brooks alex-jw-brooks May 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, we should be able to use 5 now, just pushed to reduce it. This line is just a refactor, since req_len = end - start is already set earlier. The main related fix is in this related PR #3318 from @princepride, which fixed async scheduling counts. But we still need this PR to allow disabling async scheduling correctly for AR models.

Basically the problem was conflating token counts with positions in RoPE. In this case, the new_positions for RoPE should be something like this:

 [0, 0, 0, ..., 0, 0, 0, 0, 1, 1, 1, 1, ..., 1, 1, 2, 3, 4, 5, 6, 7, 8, 9] # len is ~7600

because visual components share the same position. So then for this example, num_post_text is 8, and the "ropes" is 10 (M=0, so it's 2 + num_post_text).

In the previous code, if an async placeholder was let through, we'd get something like num_computed_tokens=7601, and do this.

            if num_computed_tokens > prefill_rope:
                meta["ropes"] = [num_computed_tokens]

but num_computed_tokens is tokens, not positions. So it would overwrite that 10 -> 7601 because it was not considering that those positions should be shared, which would completely destroy outputs.

With the fix on the model side of things, letting an async placeholder through instead looks like a decode, so it would overwrite 10 -> 11. However, as these are very close, it doesn't have a dramatic impact. This is now fixed too though.

There are some things that still look a little weird to me in the AR part of Bagel, but at least positions now match synchronous so results are fine, so I think this PR should be fine now

}
)
img2img_idx += 1
Expand Down
10 changes: 6 additions & 4 deletions vllm_omni/worker/gpu_ar_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,10 +313,12 @@ def execute_model(
if finished_reqs and hasattr(self.model, "get_kv_transfer_metadata"):
for req_id, data in finished_reqs.items():
try:
req_idx = self.input_batch.req_id_to_index.get(req_id)
num_computed = (
int(self.input_batch.num_computed_tokens_cpu[req_idx]) if req_idx is not None else None
)
# NOTE: seq_len is the same as num_computed_tokens_cpu in current
# async scheduling, since both exclude async placeholders. We use
# seq_len since we control it, just in case upstream async scheduler
# semantics change in the future.
num_computed = data.get("seq_len")
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly for safety since this is what is set when we mark the request for transfer


model_meta = self.model.get_kv_transfer_metadata(
req_id,
num_computed_tokens=num_computed,
Expand Down
Loading