diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py index 22dcf26d09..7abe8fc869 100644 --- a/tests/test_config_factory.py +++ b/tests/test_config_factory.py @@ -11,7 +11,6 @@ import pytest from vllm_omni.config.stage_config import ( - _EXECUTION_TYPE_TO_SCHEDULER, _PIPELINE_REGISTRY, ModelPipeline, PipelineConfig, @@ -20,6 +19,7 @@ StageExecutionType, StagePipelineConfig, StageType, + _resolve_scheduler, build_stage_runtime_overrides, register_pipeline, strip_parent_engine_args, @@ -768,25 +768,29 @@ def test_validate_no_stages(self): p = PipelineConfig(model_type="t", model_arch="A") assert any("no stages" in e.lower() for e in p.validate()) - def test_get_scheduler_cls(self): - p = PipelineConfig( - model_type="t", - model_arch="A", - stages=( - StagePipelineConfig(stage_id=0, model_stage="a", execution_type=StageExecutionType.LLM_AR), - StagePipelineConfig( - stage_id=1, model_stage="b", execution_type=StageExecutionType.LLM_GENERATION, input_sources=(0,) - ), - ), - ) - assert "OmniARScheduler" in p.get_scheduler_cls(0) - assert "OmniGenerationScheduler" in p.get_scheduler_cls(1) - -class TestExecutionTypeToScheduler: - def test_all_types_mapped(self): +class TestResolveScheduler: + def test_all_execution_types_handled(self): for et in StageExecutionType: - assert et in _EXECUTION_TYPE_TO_SCHEDULER + _resolve_scheduler(et) + + def test_ar_sync_when_false(self): + cls = _resolve_scheduler(StageExecutionType.LLM_AR, async_scheduling=False) + assert cls is not None + assert "Async" not in cls.__name__ + + def test_ar_async_when_true(self): + cls = _resolve_scheduler(StageExecutionType.LLM_AR, async_scheduling=True) + assert cls is not None + assert "Async" in cls.__name__ + + def test_generation(self): + cls = _resolve_scheduler(StageExecutionType.LLM_GENERATION) + assert cls is not None + assert "Generation" in cls.__name__ + + def test_diffusion_returns_none(self): + assert _resolve_scheduler(StageExecutionType.DIFFUSION) is None class TestPipelineRegistry: diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py index 44cc83baea..d4e3366772 100644 --- a/vllm_omni/config/stage_config.py +++ b/vllm_omni/config/stage_config.py @@ -13,9 +13,10 @@ from typing import Any from vllm.logger import init_logger +from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler from vllm_omni.config.yaml_util import create_config, load_yaml_config, to_dict -from vllm_omni.core.sched.omni_ar_scheduler import OmniARScheduler +from vllm_omni.core.sched.omni_ar_scheduler import OmniARAsyncScheduler, OmniARScheduler from vllm_omni.core.sched.omni_generation_scheduler import OmniGenerationScheduler _MODELS_DIR = Path(__file__).resolve().parent.parent / "model_executor" / "models" @@ -136,18 +137,27 @@ class StageExecutionType(str, Enum): DIFFUSION = "diffusion" -# Mapping class refs (not dotted-path strings) so module/class renames fail -# at import time instead of lazily at scheduler resolution. YAML overrides -# and downstream serialization still use the dotted-path string form; the -# conversion happens at the map lookup site via _scheduler_path(). -_EXECUTION_TYPE_TO_SCHEDULER: dict[StageExecutionType, type | None] = { - StageExecutionType.LLM_AR: OmniARScheduler, - StageExecutionType.LLM_GENERATION: OmniGenerationScheduler, - StageExecutionType.DIFFUSION: None, -} +def _resolve_scheduler( + execution_type: StageExecutionType, + async_scheduling: bool = True, +) -> type[VLLMScheduler] | None: + """Return the scheduler class for the given execution_type. + + NOTE: For AutoRegressive stages, we have two schedulers for sync / async + respectively, and decide which to used based on the value of async_scheduling. + For other execution types, async_scheduling is not used. + """ + if execution_type == StageExecutionType.LLM_AR: + if not async_scheduling: + return OmniARScheduler + return OmniARAsyncScheduler + if execution_type == StageExecutionType.LLM_GENERATION: + return OmniGenerationScheduler + # Diffusion currently returns None here. + return None -def _scheduler_path(cls: type | None) -> str | None: +def _scheduler_path(cls: type[VLLMScheduler] | None) -> str | None: """Return the dotted import path for a scheduler class (``None`` passes through).""" if cls is None: return None @@ -213,18 +223,6 @@ def get_stage(self, stage_id: int) -> StagePipelineConfig | None: return stage return None - def get_scheduler_cls(self, stage_id: int) -> str | None: - """Return the inferred scheduler class path for a stage. - - Returns ``None`` for DIFFUSION stages (no vLLM scheduler). Raises - ``ValueError`` if ``stage_id`` doesn't exist in this pipeline, and - ``KeyError`` if ``execution_type`` isn't in the scheduler map. - """ - stage = self.get_stage(stage_id) - if stage is None: - raise ValueError(f"Pipeline {self.model_type!r} has no stage with id {stage_id}") - return _scheduler_path(_EXECUTION_TYPE_TO_SCHEDULER[stage.execution_type]) - def validate(self) -> list[str]: """Return list of topology errors (empty if valid).""" errors: list[str] = [] @@ -796,6 +794,12 @@ def merge_pipeline_deploy( stage_type, worker_type = _resolve_execution_mode(ps.execution_type) input_proc, next_stage_proc = _select_processor_funcs(ps, deploy.async_chunk) engine_args = _build_engine_args(ps, ds, pipeline, deploy, next_stage_proc) + sched_cls = _resolve_scheduler( + ps.execution_type, + engine_args.get("async_scheduling", True), + ) + if ps.execution_type == StageExecutionType.LLM_AR: + engine_args["async_scheduling"] = sched_cls is OmniARAsyncScheduler extras = _build_extras(ps, ds) runtime: dict[str, Any] = {"process": True} if ds is not None: @@ -811,7 +815,7 @@ def merge_pipeline_deploy( final_output=ps.final_output, final_output_type=ps.final_output_type, worker_type=worker_type, - scheduler_cls=_scheduler_path(_EXECUTION_TYPE_TO_SCHEDULER.get(ps.execution_type)), + scheduler_cls=_scheduler_path(sched_cls), hf_config_name=ps.hf_config_name, is_comprehension=ps.owns_tokenizer, yaml_engine_args=engine_args, @@ -1280,6 +1284,23 @@ def _parse_pipeline_yaml(cls, path: Path, model_type: str) -> ModelPipeline: # YAMLs (worker_type, scheduler_cls, etc.) — read from both places. worker_type = stage_data.get("worker_type", None) or yaml_engine_args.pop("worker_type", None) scheduler_cls = stage_data.get("scheduler_cls", None) or yaml_engine_args.pop("scheduler_cls", None) + if scheduler_cls: + async_sched = yaml_engine_args.get("async_scheduling") + if async_sched is not None: + logger.warning( + "Stage %s: async_scheduling=%r and scheduler_cls=%r " + "should not be set together. scheduler_cls will take " + "precedence for which scheduler is used.", + stage_data.stage_id, + async_sched, + scheduler_cls, + ) + else: + logger.warning( + "Stage %s: scheduler_cls=%r is deprecated. Use async_scheduling instead.", + stage_data.stage_id, + scheduler_cls, + ) hf_config_name = stage_data.get("hf_config_name", None) or yaml_engine_args.pop("hf_config_name", None) model_stage = getattr(stage_data, "model_stage", None) or yaml_engine_args.pop("model_stage", None) diff --git a/vllm_omni/core/sched/__init__.py b/vllm_omni/core/sched/__init__.py index ecf18d07ac..0f80769d32 100644 --- a/vllm_omni/core/sched/__init__.py +++ b/vllm_omni/core/sched/__init__.py @@ -2,11 +2,12 @@ Scheduling components for vLLM-Omni. """ -from .omni_ar_scheduler import OmniARScheduler +from .omni_ar_scheduler import OmniARAsyncScheduler, OmniARScheduler from .omni_generation_scheduler import OmniGenerationScheduler from .output import OmniNewRequestData __all__ = [ + "OmniARAsyncScheduler", "OmniARScheduler", "OmniGenerationScheduler", "OmniNewRequestData", diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index ccdcf16435..2975848128 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -9,9 +9,9 @@ from vllm.distributed.kv_events import KVEventBatch from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.logger import init_logger -from vllm.v1.core.sched.async_scheduler import AsyncScheduler as VLLMScheduler +from vllm.v1.core.sched.async_scheduler import AsyncScheduler as AsyncVLLMScheduler from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.core.sched.scheduler import Scheduler as SyncScheduler +from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler from vllm.v1.core.sched.utils import remove_all from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs from vllm.v1.metrics.perf import PerfStats @@ -41,12 +41,9 @@ def to_dict(self) -> dict[str, Any]: class OmniARScheduler(OmniSchedulerMixin, VLLMScheduler): - """ - OmniARScheduler: Scheduler for vLLM-Omni multimodal processing. - - This scheduler extends vLLM's scheduler to support multimodal and - non-autoregressive processing with additional fields and methods - specific to vLLM-Omni. + """Synchronous AutoRegressive scheduler for vLLM-Omni. This class is also + used as a base class for the OmniARAsyncScheduler and holds most of the + core scheduling logic. """ def __init__(self, *args, **kwargs): @@ -83,29 +80,9 @@ def __init__(self, *args, **kwargs): def _get_confirmed_num_computed_tokens(self, request: Request) -> int: """num_computed_tokens minus async placeholders (KV actually on GPU).""" + # Output placeholders are zero when async scheduling isn't used return request.num_computed_tokens - request.num_output_placeholders - def _update_request_with_output(self, request: Request, new_token_ids: list[int]) -> tuple[list[int], bool]: - """Append output tokens, then cache blocks up to the confirmed count - so KV transfer never sees blocks whose data has not been computed yet. - """ - if request.discard_latest_async_tokens: - request.discard_latest_async_tokens = False - return [], False - - status_before_update = request.status - - new_token_ids, stopped = SyncScheduler._update_request_with_output(self, request, new_token_ids) - - request.num_output_placeholders -= len(new_token_ids) - assert request.num_output_placeholders >= 0 - - if status_before_update == RequestStatus.RUNNING: - confirmed = self._get_confirmed_num_computed_tokens(request) - self.kv_cache_manager.cache_blocks(request, confirmed) - - return new_token_ids, stopped - def _get_kv_transfer_criteria(self) -> dict | None: # Note: vllm_config is available in Scheduler after super().__init__ if not hasattr(self, "vllm_config"): @@ -786,3 +763,7 @@ def get_finished_requests_needing_kv_transfer(self) -> dict[str, dict]: self.requests_needing_kv_transfer.clear() return requests + + +class OmniARAsyncScheduler(OmniARScheduler, AsyncVLLMScheduler): + """Asynchronous AutoRegressive scheduler.""" diff --git a/vllm_omni/model_executor/models/bagel/bagel.py b/vllm_omni/model_executor/models/bagel/bagel.py index a805291dd7..4bc318e30a 100644 --- a/vllm_omni/model_executor/models/bagel/bagel.py +++ b/vllm_omni/model_executor/models/bagel/bagel.py @@ -562,6 +562,7 @@ def get_kv_transfer_metadata( *, num_computed_tokens: int | None = None, ) -> dict[str, Any] | None: + # NOTE: num_computed_tokens will not include async placeholders meta = self._ropes_metadata.pop(req_id, None) if meta is None: return None @@ -854,7 +855,7 @@ def _adjust_positions_for_img2img( { "ropes": [rope], "image_shape": [img_H, img_W], - "prefill_position_count": int(end - start), + "prefill_position_count": req_len, } ) img2img_idx += 1 diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 93c38b9b10..143a373fe8 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -313,10 +313,12 @@ def execute_model( if finished_reqs and hasattr(self.model, "get_kv_transfer_metadata"): for req_id, data in finished_reqs.items(): try: - req_idx = self.input_batch.req_id_to_index.get(req_id) - num_computed = ( - int(self.input_batch.num_computed_tokens_cpu[req_idx]) if req_idx is not None else None - ) + # NOTE: seq_len is the same as num_computed_tokens_cpu in current + # async scheduling, since both exclude async placeholders. We use + # seq_len since we control it, just in case upstream async scheduler + # semantics change in the future. + num_computed = data.get("seq_len") + model_meta = self.model.get_kv_transfer_metadata( req_id, num_computed_tokens=num_computed,