Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ openai-whisper>=20250625
imageio[ffmpeg]>=2.37.2
sox>=1.5.0
prettytable>=3.8.0
aenum==3.1.16
182 changes: 182 additions & 0 deletions tests/distributed/omni_connectors/test_chunk_transfer_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import threading
from collections import deque
from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest
import torch
from vllm.v1.request import RequestStatus

from vllm_omni.distributed.omni_connectors.transfer_adapter.base import OmniTransferAdapterBase
from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import (
OmniChunkTransferAdapter,
)
from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]


class DummyWaitingQueue(list):
def prepend_requests(self, requests):
self[:0] = list(requests)

def add_request(self, request):
self.append(request)


def _req(req_id: str, status: RequestStatus, external_req_id: str | None = None):
return SimpleNamespace(
request_id=req_id,
external_req_id=external_req_id or req_id,
status=status,
prompt_token_ids=[],
num_computed_tokens=0,
additional_information=None,
)


@pytest.fixture
def build_adapter(monkeypatch):
def _build(*, stage_id: int = 1, model_mode: str = "ar", max_num_seqs: int = 2):
connector = MagicMock()
connector.stage_id = stage_id
connector.get.return_value = None
connector.put.return_value = (True, 1, {})

def _fake_base_init(self, config):
self.config = config
self._pending_load_reqs = {}
self._finished_load_reqs = set()
self._pending_save_reqs = {}
self._finished_save_reqs = set()
self.stop_event = threading.Event()
self.lock = threading.Lock()

monkeypatch.setattr(OmniTransferAdapterBase, "__init__", _fake_base_init)
monkeypatch.setattr(
OmniChunkTransferAdapter,
"create_connector",
classmethod(lambda cls, _model_config: connector),
)

model_config = SimpleNamespace(worker_type=model_mode)
scheduler_config = SimpleNamespace(max_num_seqs=max_num_seqs)
adapter = OmniChunkTransferAdapter(
SimpleNamespace(model_config=model_config, scheduler_config=scheduler_config)
)
return adapter, connector

return _build


@pytest.mark.parametrize(
("raw_cfg", "expected_name", "expected_extra"),
[
(None, "SharedMemoryConnector", {}),
(SimpleNamespace(name="YuanrongConnector", extra={"k": "v"}), "YuanrongConnector", {"k": "v"}),
],
)
def test_create_connector_config_parsing(monkeypatch, raw_cfg, expected_name, expected_extra):
captured = {}

def _fake_create(spec):
captured["spec"] = spec
return "ok"

monkeypatch.setattr(
"vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter"
".OmniConnectorFactory.create_connector",
_fake_create,
)

model_config = SimpleNamespace(stage_connector_config=raw_cfg) if raw_cfg is not None else SimpleNamespace()
connector = OmniChunkTransferAdapter.create_connector(model_config)

assert connector == "ok"
assert isinstance(captured["spec"], ConnectorSpec)
assert captured["spec"].name == expected_name
assert captured["spec"].extra == expected_extra


def test_load_poll(build_adapter):
adapter, connector = build_adapter(stage_id=2, model_mode="ar")
request = _req("req-1", RequestStatus.WAITING, external_req_id="external-1")

adapter.load_async(request)
payload = {"code_predictor_codes": [[1]], "hidden_states": torch.tensor([[2.0]]), "finished": True}
connector.get.return_value = (payload, 16)
adapter._poll_single_request("req-1")

connector.get.assert_called_once_with("1", "2", "external-1_1_0")
assert request.additional_information == payload
assert adapter.get_req_chunk["req-1"] == 1
assert "req-1" in adapter._finished_load_reqs
assert "req-1" in adapter.finished_requests
assert "req-1" not in adapter._pending_load_reqs


def test_save_async(build_adapter):
adapter, _ = build_adapter(stage_id=1)
request = SimpleNamespace(external_req_id="external-1")

adapter.custom_process_next_stage_input_func = lambda **kwargs: {"x": [1], "finished": False}
adapter.save_async(pooling_output=None, request=request)
adapter.custom_process_next_stage_input_func = lambda **kwargs: {}
adapter.save_async(pooling_output=None, request=request)

assert adapter.put_req_chunk["external-1"] == 1
queued = adapter._pending_save_reqs["external-1"]
assert len(queued) == 1
assert queued[0]["put_key"] == "external-1_1_0"


def test_update_request_payload(build_adapter):
adapter, _ = build_adapter()

adapter._update_request_payload("ext", {"h": torch.tensor([[1.0]]), "codes": [1], "finished": False})
merged = adapter._update_request_payload("ext", {"h": torch.tensor([[2.0]]), "codes": [2], "finished": True})

assert torch.equal(merged["h"], torch.tensor([[1.0], [2.0]]))
assert merged["codes"] == [1, 2]
assert merged["finished"] is True


def test_process_and_restore_queues(build_adapter):
adapter, _ = build_adapter(stage_id=1, max_num_seqs=8)
waiting_req = _req("w1", RequestStatus.WAITING)
running_req = _req("r1", RequestStatus.RUNNING)
waiting_queue = DummyWaitingQueue([waiting_req])
running_queue = [running_req]

adapter.process_pending_chunks(waiting_queue, running_queue)
assert waiting_req.status == RequestStatus.WAITING_FOR_CHUNK
assert running_req.status == RequestStatus.WAITING_FOR_CHUNK
assert waiting_queue == []
assert running_queue == []

adapter.restore_queues(waiting_queue, running_queue)
assert waiting_queue == [waiting_req]
assert running_queue == [running_req]
assert adapter.waiting_for_chunk_waiting_requests == deque()
assert adapter.waiting_for_chunk_running_requests == deque()


def test_postprocess_scheduler_output(build_adapter):
adapter, _ = build_adapter()
adapter.requests_with_ready_chunks = {"new-ready", "cached-ready", "leftover"}

scheduler_output = SimpleNamespace(
scheduled_new_reqs=[SimpleNamespace(req_id="new-ready")],
scheduled_cached_reqs=SimpleNamespace(req_ids=["cached-ready", "missing"]),
)
requests = {"cached-ready": SimpleNamespace(additional_information={"k": "v"})}

adapter.postprocess_scheduler_output(scheduler_output, requests)

cached_info = scheduler_output.scheduled_cached_reqs.additional_information
assert cached_info["cached-ready"] == {"k": "v"}
assert cached_info["missing"] is None
assert adapter.requests_with_ready_chunks == {"leftover"}
2 changes: 2 additions & 0 deletions vllm_omni/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class OmniModelConfig(ModelConfig):
(default: "thinker")
model_arch: Model architecture name
(default: "Qwen2_5OmniForConditionalGeneration")
worker_type: Model Type, e.g., "ar" or "generation"
engine_output_type: Optional output type specification for the engine.
Used to route outputs to appropriate processors (e.g., "image",
"audio", "latents"). If None, output type is inferred.
Expand All @@ -63,6 +64,7 @@ class OmniModelConfig(ModelConfig):
async_chunk: bool = False
model_stage: str = "thinker"
model_arch: str = "Qwen2_5OmniForConditionalGeneration"
worker_type: str | None = None
engine_output_type: str | None = None
hf_config_name: str | None = None
custom_process_next_stage_input_func: str | None = None
Expand Down
53 changes: 24 additions & 29 deletions vllm_omni/core/sched/omni_ar_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import importlib
from collections import defaultdict
from dataclasses import asdict, dataclass
from time import time
Expand All @@ -20,9 +19,9 @@
from vllm.v1.spec_decode.metrics import SpecDecodingStats

from vllm_omni.core.sched.output import OmniSchedulerOutput
from vllm_omni.distributed.omni_connectors.adapter import get_chunk, put_chunk
from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory
from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec
from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import (
OmniChunkTransferAdapter,
)

logger = init_logger(__name__)

Expand Down Expand Up @@ -65,24 +64,9 @@ def __init__(self, *args, **kwargs):
# Track requests that have already triggered prefill transfer to avoid duplicates
self.transfer_triggered_requests: set[str] = set()
model_config = self.vllm_config.model_config
self.omni_connector = None
if model_config.async_chunk:
connector_config = model_config.stage_connector_config
connector_specs = ConnectorSpec(
name=connector_config.get("name", "SharedMemoryConnector"),
extra=connector_config.get("extra", {}),
)
self.omni_connector = OmniConnectorFactory.create_connector(connector_specs)

custom_process_next_stage_input_func = getattr(
self.vllm_config.model_config, "custom_process_next_stage_input_func", None
)
if custom_process_next_stage_input_func:
module_path, func_name = custom_process_next_stage_input_func.rsplit(".", 1)
module = importlib.import_module(module_path)
self.custom_process_next_stage_input_func = getattr(module, func_name)

self.stage_id = getattr(self.vllm_config.model_config, "stage_id", None)
self.chunk_transfer_adapter = None
if getattr(model_config, "async_chunk", False):
self.chunk_transfer_adapter = OmniChunkTransferAdapter(self.vllm_config)

def _get_kv_transfer_criteria(self) -> dict | None:
# Note: vllm_config is available in Scheduler after super().__init__
Expand Down Expand Up @@ -152,7 +136,15 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int
return False

def schedule(self) -> SchedulerOutput: # type: ignore[override]
scheduler_output = super().schedule()
if self.chunk_transfer_adapter:
self.chunk_transfer_adapter.process_pending_chunks(self.waiting, self.running)

try:
scheduler_output = super().schedule()
finally:
if self.chunk_transfer_adapter:
# Add request waiting for chunk to the waiting and running queue
self.chunk_transfer_adapter.restore_queues(self.waiting, self.running)
try:
# Late import to avoid circulars in some launch modes
from .output import OmniNewRequestData
Expand Down Expand Up @@ -181,9 +173,8 @@ def schedule(self) -> SchedulerOutput: # type: ignore[override]
new_list.append(omni_nr)

scheduler_output.scheduled_new_reqs = new_list # type: ignore[assignment]
if self.omni_connector is not None:
get_chunk(self.omni_connector, scheduler_output)

if self.chunk_transfer_adapter:
self.chunk_transfer_adapter.postprocess_scheduler_output(scheduler_output, self.requests)
# Add information about requests needing KV cache transfer
finished_reqs = self.get_finished_requests_needing_kv_transfer()
except Exception:
Expand Down Expand Up @@ -312,6 +303,11 @@ def update_from_output(
kv_transfer_params = self._free_request(request)
if status_before_stop == RequestStatus.RUNNING:
stopped_running_reqs.add(request)
elif status_before_stop == RequestStatus.WAITING_FOR_CHUNK:
# In async chunk mode, request may be in either queue.
# Remove from both to avoid stale queue entries.
stopped_running_reqs.add(request)
stopped_preempted_reqs.add(request)
else:
stopped_preempted_reqs.add(request)

Expand Down Expand Up @@ -355,9 +351,8 @@ def update_from_output(
num_nans_in_logits=request.num_nans_in_logits,
)
)
if self.omni_connector is not None:
custom_process_next_stage_input_func = self.custom_process_next_stage_input_func
put_chunk(self.omni_connector, pooler_output, request, custom_process_next_stage_input_func)
if self.chunk_transfer_adapter is not None:
self.chunk_transfer_adapter.save_async(pooler_output, request)
else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
Expand Down
Loading