Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
690 changes: 690 additions & 0 deletions tests/core/sched/test_chunk_scheduling_coordinator.py

Large diffs are not rendered by default.

1,419 changes: 1,419 additions & 0 deletions tests/worker/test_omni_connector_mixin.py

Large diffs are not rendered by default.

380 changes: 380 additions & 0 deletions vllm_omni/core/sched/omni_scheduling_coordinator.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion vllm_omni/diffusion/worker/diffusion_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@
from vllm_omni.diffusion.worker.utils import DiffusionRequestState, RunnerOutput
from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager
from vllm_omni.platforms import current_omni_platform
from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin

logger = init_logger(__name__)


class DiffusionModelRunner:
class DiffusionModelRunner(OmniConnectorModelRunnerMixin):
"""
Model runner that handles model loading and execution for diffusion models.

Expand Down
28 changes: 28 additions & 0 deletions vllm_omni/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,33 @@
from vllm_omni.inputs.data import OmniPromptType


@dataclass
class OmniConnectorOutput:
"""Communication results from Model Runner to Scheduler.

Carries transfer readiness signals so the Scheduler can make scheduling
decisions without ever calling connector.put()/get() directly.

Attributes:
chunk_ready_req_ids: Request IDs with newly arrived chunks this cycle.
chunk_finished_req_ids: Request IDs whose final chunk has arrived.
request_metadata: Lightweight scheduling metadata keyed by request ID
(e.g. next_stage_prompt_len, code_predictor_codes, left_context_size).
Full payloads are owned by the Model Runner's local cache.
kv_sent_req_ids: Request IDs whose KV cache was successfully sent.
stage_recv_req_ids: Request IDs that received batch stage inputs.
has_pending_kv_work: True if the mixin has pending, active, or
completed KV transfers that the scheduler should account for.
"""

chunk_ready_req_ids: set[str] = field(default_factory=set)
chunk_finished_req_ids: set[str] = field(default_factory=set)
request_metadata: dict[str, dict[str, Any]] = field(default_factory=dict)
kv_sent_req_ids: list[str] = field(default_factory=list)
stage_recv_req_ids: set[str] = field(default_factory=set)
has_pending_kv_work: bool = False


class OmniModelRunnerOutput(ModelRunnerOutput):
"""Model runner output for omni models.

Expand All @@ -24,6 +51,7 @@ class OmniModelRunnerOutput(ModelRunnerOutput):
# IDs of requests whose KV cache has been extracted from GPU/NPU to CPU.
# The Scheduler can safely free the block tables for these requests.
kv_extracted_req_ids: list[str] | None = None
omni_connector_output: OmniConnectorOutput | None = None


@dataclass
Expand Down
3 changes: 2 additions & 1 deletion vllm_omni/worker/gpu_ar_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager
from vllm_omni.outputs import OmniModelRunnerOutput
from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner
from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin

logger = init_logger(__name__)

Expand All @@ -60,7 +61,7 @@ class ExecuteModelState(NamedTuple):
slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None


class GPUARModelRunner(OmniGPUModelRunner):
class GPUARModelRunner(OmniGPUModelRunner, OmniConnectorModelRunnerMixin):
"""Autoregressive GPU model runner that returns hidden states per request.

Follows the v0.12 two-phase execute/sample flow from GPUModelRunner, and
Expand Down
3 changes: 2 additions & 1 deletion vllm_omni/worker/gpu_generation_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@
from vllm_omni.outputs import OmniModelRunnerOutput
from vllm_omni.worker.gpu_ar_model_runner import ExecuteModelState
from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner
from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin

logger = logging.getLogger(__name__)


class GPUGenerationModelRunner(OmniGPUModelRunner):
class GPUGenerationModelRunner(OmniGPUModelRunner, OmniConnectorModelRunnerMixin):
"""Generation model runner for vLLM-Omni (non-autoregressive).

- Reuses GPUModelRunner preparation, multimodal handling, and TP/PP/DP glue.
Expand Down
Loading
Loading