diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/core/sched/__init__.py b/tests/core/sched/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/core/sched/test_omni_ar_scheduler_streaming.py b/tests/core/sched/test_omni_ar_scheduler_streaming.py new file mode 100644 index 00000000000..dd47de2fd19 --- /dev/null +++ b/tests/core/sched/test_omni_ar_scheduler_streaming.py @@ -0,0 +1,205 @@ +"""Unit tests for OmniARScheduler streaming input handling. + +Tests the key behavioral changes for streaming input support: +1. finish_reason is captured BEFORE _handle_stopped_request (which may reset status) +2. _free_request is only called when the request truly finishes +3. Output is always emitted when stopped (even without new tokens) +""" + +from collections import deque +from unittest.mock import MagicMock, Mock, PropertyMock, patch + +import pytest +from vllm.sampling_params import SamplingParams +from vllm.v1.engine import EngineCoreOutput, FinishReason +from vllm.v1.request import Request, RequestStatus + +from vllm_omni.core.sched.omni_ar_scheduler import OmniARScheduler + + +def _make_mock_request( + request_id: str = "req-0", + resumable: bool = False, + status: RequestStatus = RequestStatus.RUNNING, +) -> Mock: + """Create a mock Request for testing.""" + req = MagicMock(spec=Request) + req.request_id = request_id + req.client_index = 0 + req.status = status + req.resumable = resumable + req.output_token_ids = [10, 11] + req.num_tokens = 5 + req.num_computed_tokens = 5 + req.num_output_placeholders = 0 + req.sampling_params = SamplingParams(max_tokens=10) + req.pooling_params = None + req.stop_reason = None + req.trace_headers = None + req.num_cached_tokens = 0 + req.num_nans_in_logits = None + req.streaming_queue = deque() if resumable else None + req.take_events.return_value = None + + # Simulate get_finished_reason returning LENGTH before reset + 
class TestFinishReasonCapturedBeforeHandleStop:
    """Check that finish_reason is read before _handle_stopped_request runs.

    Both tests operate on a mocked request built by ``_make_mock_request``;
    they model the ordering the scheduler depends on and do not call the
    scheduler itself.
    """

    def test_finish_reason_captured_for_resumable_request(self):
        """A reason captured before the status reset must survive the reset.

        For resumable requests, _handle_stopped_request returns False
        (request continues) and may reset the status to a waiting state.
        Querying get_finished_reason after that point would yield None, so
        the value has to be captured first.
        """
        request = _make_mock_request("req-resume", resumable=True)
        request.status = RequestStatus.FINISHED_LENGTH_CAPPED

        # Capture finish_reason BEFORE the (simulated) handle step.
        captured_reason = request.get_finished_reason()
        assert captured_reason == FinishReason.LENGTH

        # Simulate what the handle step does to a resumable request: the
        # status goes back to waiting and the reason is no longer available.
        request.status = RequestStatus.WAITING_FOR_STREAMING_REQ
        request.get_finished_reason.return_value = None

        # A late query loses the reason; the early capture keeps it.
        assert request.get_finished_reason() is None
        assert captured_reason == FinishReason.LENGTH

    def test_finish_reason_captured_for_non_resumable(self):
        """For non-resumable, behavior is the same - capture before handle."""
        request = _make_mock_request("req-normal", resumable=False)
        request.status = RequestStatus.FINISHED_STOPPED
        request.get_finished_reason.return_value = FinishReason.STOP

        finish_reason = request.get_finished_reason()
        assert finish_reason == FinishReason.STOP
class TestOutputEmittedWhenStopped:
    """Test that output is always emitted when stopped, even without new tokens.

    The predicate under test is a local copy of the condition
    ``new_token_ids or pooler_output is not None or kv_transfer_params or stopped``;
    these tests do not call the scheduler.
    """

    @staticmethod
    def _should_emit(new_token_ids, pooler_output, kv_transfer_params, stopped) -> bool:
        """Return True when any of the four emission triggers is present."""
        # Kept in one place so every case below checks the same expression.
        return bool(new_token_ids or pooler_output is not None or kv_transfer_params or stopped)

    def test_output_emitted_on_stopped_without_tokens(self):
        """Output is emitted when ``stopped`` is True and nothing else is set.

        This is critical for streaming input: the last sub-request
        may produce no new tokens but must still signal finish.
        """
        assert self._should_emit([], None, None, True) is True

    def test_no_output_when_nothing_to_emit(self):
        """Without tokens, pooler, kv_params, or stop, nothing emitted."""
        assert self._should_emit([], None, None, False) is False

    def test_output_emitted_with_tokens(self):
        """Normal case: tokens present means output emitted."""
        assert self._should_emit([42], None, None, False) is True

    def test_output_emitted_with_pooler_output(self):
        """A non-None pooler output alone triggers emission."""
        # Falsy-but-not-None value proves the check is ``is not None``.
        assert self._should_emit([], 0, None, False) is True

    def test_output_emitted_with_kv_transfer_params(self):
        """Non-empty kv_transfer_params alone triggers emission."""
        assert self._should_emit([], None, {"kv": "params"}, False) is True
def _make_tokenizer():
    """Build a bare MagicMock standing in for a tokenizer.

    Only two things are pinned: ``eos_token_id`` is 2 and
    ``convert_ids_to_tokens`` returns a list holding one empty string.
    """
    stub = MagicMock()
    stub.configure_mock(
        **{
            "eos_token_id": 2,
            "convert_ids_to_tokens.return_value": [""],
        }
    )
    return stub
def _make_engine_core_output(
    request_id: str = "req-0",
    new_token_ids: list[int] | None = None,
    finish_reason: FinishReason | None = None,
    pooling_output: torch.Tensor | None = None,
) -> EngineCoreOutput:
    """Create a minimal EngineCoreOutput.

    Args:
        request_id: Id stamped on the output.
        new_token_ids: Tokens carried by the output. ``None`` selects the
            default ``[10]``; an explicit empty list is passed through
            unchanged so a "no new tokens" output can be built.
        finish_reason: Finish reason to attach, or ``None``.
        pooling_output: Value assigned to ``pooling_output``.

    Returns:
        The populated EngineCoreOutput.
    """
    # ``is None`` rather than ``or``: ``[]`` is a meaningful input and must
    # not be silently replaced by the default token.
    token_ids = [10] if new_token_ids is None else new_token_ids
    eco = EngineCoreOutput(
        request_id=request_id,
        new_token_ids=token_ids,
        finish_reason=finish_reason,
    )
    eco.pooling_output = pooling_output
    eco.stop_reason = None
    eco.kv_transfer_params = None
    eco.num_cached_tokens = 0
    eco.routed_experts = None
    return eco
class TestProcessOutputsStreaming:
    """Tests for process_outputs handling of streaming input requests."""

    def test_streaming_output_marked_not_finished(self, processor):
        """When streaming_input is True, output.finished should be False."""
        request = _make_engine_core_request("req-stream", resumable=True)
        processor.add_request(request, prompt="Hello")

        state = processor.request_states["req-stream"]
        assert state.streaming_input is True

        eco = _make_engine_core_output(
            "req-stream",
            new_token_ids=[10],
            finish_reason=FinishReason.LENGTH,
        )

        result = processor.process_outputs([eco])
        # No per-request queue was supplied, so outputs are expected in the
        # returned request_outputs list.
        # NOTE(review): assumes the returned object exposes
        # ``request_outputs`` — confirm against process_outputs.
        for output in result.request_outputs:
            assert output.finished is False
        # The streaming request should remain in request_states (not freed).
        assert "req-stream" in processor.request_states

    def test_non_streaming_output_freed(self, processor):
        """Non-streaming finished request should be freed."""
        request = _make_engine_core_request("req-normal", resumable=False)
        processor.add_request(request, prompt="Hello")

        eco = _make_engine_core_output(
            "req-normal",
            new_token_ids=[10],
            finish_reason=FinishReason.LENGTH,
        )

        processor.process_outputs([eco])
        assert "req-normal" not in processor.request_states

    def test_streaming_with_queued_update_applies_update(self, processor):
        """When streaming request finishes and has queued update, it's applied."""
        request = _make_engine_core_request("req-queued", resumable=True)
        processor.add_request(request, prompt="Hello")

        state = processor.request_states["req-queued"]
        assert state.streaming_input is True

        update = StreamingUpdate(
            prompt=" world",
            prompt_token_ids=[4, 5, 6],
            arrival_time=2000.0,
        )
        state.input_chunk_queue.append(update)

        eco = _make_engine_core_output(
            "req-queued",
            new_token_ids=[10],
            finish_reason=FinishReason.LENGTH,
        )

        processor.process_outputs([eco])

        # Request should still be in states (streaming continues)
        assert "req-queued" in processor.request_states
        # The update should have been applied
        assert state.is_prefilling is True

    def test_streaming_empty_queue_clears_queue(self, processor):
        """When streaming request finishes and queue is empty, queue is set to None."""
        request = _make_engine_core_request("req-empty-q", resumable=True)
        processor.add_request(request, prompt="Hello")

        state = processor.request_states["req-empty-q"]
        assert state.streaming_input is True
        # Queue is empty (no pending updates)
        assert len(state.input_chunk_queue) == 0

        eco = _make_engine_core_output(
            "req-empty-q",
            new_token_ids=[10],
            finish_reason=FinishReason.LENGTH,
        )

        processor.process_outputs([eco])

        # Queue should now be None (waiting for next streaming input)
        assert state.input_chunk_queue is None
        # Request should still exist
        assert "req-empty-q" in processor.request_states
test_add_multimodal_dict_payload(self): + """Dict payload is normalized correctly.""" + state = OmniRequestState.__new__(OmniRequestState) + state.mm_type = None + state.mm_accumulated = None + + payload = {"model_outputs": torch.randn(2, 4)} + state.add_multimodal_tensor(payload, "audio") + + # "model_outputs" should be renamed to the mm_type ("audio") + assert "audio" in state.mm_accumulated + assert state.mm_accumulated["audio"].shape == (2, 4) diff --git a/tests/entrypoints/__init__.py b/tests/entrypoints/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/entrypoints/test_streaming_input_e2e.py b/tests/entrypoints/test_streaming_input_e2e.py new file mode 100644 index 00000000000..3500b70395f --- /dev/null +++ b/tests/entrypoints/test_streaming_input_e2e.py @@ -0,0 +1,238 @@ +"""End-to-end showcase test for streaming input support. + +Demonstrates how the streaming input API works with AsyncOmniLLM: + + 1. An async generator yields StreamingInput chunks over time + 2. Each chunk is processed by the engine (prompt is extended, KV cache preserved) + 3. Each sub-request produces intermediate outputs (finished=False) + 4. When the generator exhausts, a final output is emitted (finished=True) + +This is the pattern used by multi-turn speech-to-speech pipelines, where +perception tokens, LLM tokens, and TTS tokens are streamed between stages. 
@pytest.mark.asyncio
async def test_streaming_input_basic_flow(mock_llm):
    """Showcase: feed 2 prompt chunks via async generator, get 3 outputs.

    Flow:
        input_generator yields "Hello"  → one output (not finished)
        input_generator yields " world" → one output (not finished)
        input_generator exhausts        → one output (finished=True)

    ``add_request`` is replaced by a stub that drains the input generator in
    a background task and feeds the output collector; only ``generate()`` is
    the real AsyncLLM method.
    """
    request_id = "streaming_showcase"
    sampling_params = SamplingParams(
        max_tokens=5, output_kind=RequestOutputKind.DELTA
    )

    queue = RequestOutputCollector(RequestOutputKind.DELTA, request_id)
    inputs_received: list[str] = []
    # Strong references to spawned tasks: the event loop keeps only weak
    # ones, so a discarded task may be collected before it finishes.
    background_tasks: list[asyncio.Task] = []

    async def mock_add_request(req_id, prompt, params, *args, **kwargs):
        if isinstance(prompt, AsyncGenerator):

            async def handle_stream():
                async for chunk in prompt:
                    inputs_received.append(chunk.prompt)
                    queue.put(_make_output(req_id, finished=False))
                    # Yield control so the consumer can take this output
                    # before the next one is put.
                    await asyncio.sleep(0.01)
                # Final output when generator exhausts
                queue.put(_make_output(req_id, finished=True))

            background_tasks.append(asyncio.create_task(handle_stream()))
        return queue

    mock_llm.add_request = mock_add_request

    # ── The user-facing API ──
    async def input_generator() -> AsyncGenerator[StreamingInput, None]:
        yield StreamingInput(prompt="Hello", sampling_params=sampling_params)
        yield StreamingInput(prompt=" world", sampling_params=sampling_params)

    outputs: list[RequestOutput] = []
    async for output in mock_llm.generate(
        input_generator(), sampling_params, request_id
    ):
        outputs.append(output)

    # Re-raise anything the background feeder hit instead of losing it.
    await asyncio.gather(*background_tasks)

    # 2 intermediate + 1 final
    assert len(outputs) == 3
    assert outputs[0].finished is False
    assert outputs[1].finished is False
    assert outputs[2].finished is True
    assert inputs_received == ["Hello", " world"]
@pytest.mark.asyncio
async def test_streaming_input_synchronized_injection(mock_llm):
    """Showcase: synchronized input injection driven by output reception.

    The caller feeds one chunk at a time and waits for an output before
    injecting the next one:

        inject "A" → get output → inject "B" → get output → done

    ``sync_queue`` bridges the output consumer and the input producer so
    that input injection and output reception strictly alternate.
    ``add_request`` is a stub that drains the input generator in a
    background task; only ``generate()`` is the real AsyncLLM method.
    """
    request_id = "sync_showcase"
    sampling_params = SamplingParams(
        max_tokens=1, output_kind=RequestOutputKind.DELTA
    )

    queue = RequestOutputCollector(RequestOutputKind.DELTA, request_id)
    # Synchronization queue: output consumer signals input producer
    sync_queue: asyncio.Queue[int | None] = asyncio.Queue()
    inject_tokens = ["token_A", "token_B", "token_C"]
    inputs_received: list[str] = []
    # Strong references to spawned tasks: the event loop keeps only weak
    # ones, so a discarded task may be collected before it finishes.
    background_tasks: list[asyncio.Task] = []

    async def mock_add_request(req_id, prompt, params, *args, **kwargs):
        if isinstance(prompt, AsyncGenerator):

            async def handle_stream():
                async for chunk in prompt:
                    inputs_received.append(chunk.prompt)
                    queue.put(_make_output(req_id, finished=False))
                    await asyncio.sleep(0.01)
                queue.put(_make_output(req_id, finished=True))

            background_tasks.append(asyncio.create_task(handle_stream()))
        return queue

    mock_llm.add_request = mock_add_request

    # ── Input producer: waits for signal before yielding next chunk ──
    async def synchronized_input_generator() -> AsyncGenerator[StreamingInput, None]:
        yield StreamingInput(prompt="initial_prompt", sampling_params=sampling_params)
        for token in inject_tokens:
            signal = await sync_queue.get()
            if signal is None:
                break
            yield StreamingInput(prompt=token, sampling_params=sampling_params)

    # ── Output consumer: signals input producer after each output ──
    outputs: list[RequestOutput] = []
    step = 0
    async for output in mock_llm.generate(
        synchronized_input_generator(), sampling_params, request_id
    ):
        outputs.append(output)
        if not output.finished and step < len(inject_tokens):
            await sync_queue.put(step)
            step += 1

    # Re-raise anything the background feeder hit instead of losing it.
    await asyncio.gather(*background_tasks)

    # initial_prompt + 3 injected tokens = 4 intermediate + 1 final
    assert len(outputs) == 5
    assert all(not o.finished for o in outputs[:-1])
    assert outputs[-1].finished is True
    assert inputs_received == ["initial_prompt", "token_A", "token_B", "token_C"]
def test_update_streaming_request_basic(mock_runner_with_input_batch):
    """Exercise _update_streaming_request on a request already in the batch.

    After the update this test checks that:
    - prompt_token_ids, num_computed_tokens, sampling_params and block_ids
      carry the new values
    - output_token_ids is empty
    - runner.requests maps the id to the returned state
    - the id is no longer present in the InputBatch index

    prompt_embeds is passed as None, so embeds handling is not covered here.
    """
    runner = mock_runner_with_input_batch
    req_id = "streaming_req_0"

    cached_state = CachedRequestState(
        req_id=req_id,
        prompt_token_ids=[1, 2, 3],
        mm_features=[],
        sampling_params=SamplingParams(temperature=0.5),
        pooling_params=None,
        generator=None,
        block_ids=([0],),
        num_computed_tokens=3,
        output_token_ids=[10, 11],
    )
    runner.requests[req_id] = cached_state
    runner.input_batch.add_request(cached_state)
    assert req_id in runner.input_batch.req_id_to_index

    # Stand-in for the scheduler's new-request payload.
    incoming = Mock(
        prompt_token_ids=[1, 2, 3, 10, 4, 5],
        mm_features=[],
        prompt_embeds=None,
        sampling_params=SamplingParams(temperature=0.8, max_tokens=50),
        pooling_params=None,
        block_ids=([0, 1],),
        num_computed_tokens=4,
        additional_information=None,
    )

    # Called unbound so the Mock runner can act as ``self``.
    result = OmniGPUModelRunner._update_streaming_request(runner, req_id, incoming)

    assert result.prompt_token_ids == [1, 2, 3, 10, 4, 5]
    assert result.num_computed_tokens == 4
    assert result.sampling_params.temperature == 0.8
    assert result.sampling_params.max_tokens == 50
    assert result.block_ids == ([0, 1],)
    assert result.output_token_ids == []
    assert runner.requests[req_id] is result
    assert req_id not in runner.input_batch.req_id_to_index
def test_update_streaming_request_clears_output_and_updates_prompt(mock_runner_with_input_batch):
    """Test that output_token_ids are cleared and prompt grows.

    The new payload's prompt_token_ids already contain the old prompt, the
    old outputs and the new input. After the update the state must report
    that full prompt and an empty output_token_ids list.
    """
    runner = mock_runner_with_input_batch
    req_id = "streaming_clear_req"

    initial_state = CachedRequestState(
        req_id=req_id,
        prompt_token_ids=[1, 2, 3],
        mm_features=[],
        sampling_params=SamplingParams(),
        pooling_params=None,
        generator=None,
        block_ids=([0],),
        num_computed_tokens=3,
        output_token_ids=[10, 11, 12],
    )
    runner.requests[req_id] = initial_state
    runner.input_batch.add_request(initial_state)

    # New prompt includes old prompt + old outputs + new input
    new_req_data = Mock()
    new_req_data.prompt_token_ids = [1, 2, 3, 10, 11, 12, 20, 21]
    new_req_data.mm_features = []
    new_req_data.prompt_embeds = None
    new_req_data.sampling_params = SamplingParams(temperature=0.5)
    new_req_data.pooling_params = None
    new_req_data.block_ids = ([0, 1],)
    new_req_data.num_computed_tokens = 6  # old prompt + old outputs
    # Pin explicitly, as the sibling tests do: an unset attribute on a Mock
    # is an auto-created truthy Mock, not None.
    new_req_data.additional_information = None

    updated = OmniGPUModelRunner._update_streaming_request(
        runner, req_id, new_req_data
    )

    assert updated.prompt_token_ids == [1, 2, 3, 10, 11, 12, 20, 21]
    assert updated.output_token_ids == []
    assert updated.num_computed_tokens == 6
    assert updated.num_prompt_tokens == 8
initial_state + runner.input_batch.add_request(initial_state) + + mm_feature_2 = MultiModalFeatureSpec( + data=MultiModalKwargsItem.dummy("audio"), + modality="audio", + identifier="audio_2", + mm_position=PlaceholderRange(offset=14, length=5), + ) + + new_req_data = Mock() + new_req_data.prompt_token_ids = [1, 2] + [0] * 10 + [3, 100] + [0] * 5 + [4] + new_req_data.mm_features = [mm_feature_1, mm_feature_2] + new_req_data.prompt_embeds = None + new_req_data.sampling_params = SamplingParams(temperature=0.7) + new_req_data.pooling_params = None + new_req_data.block_ids = ([0, 1],) + new_req_data.num_computed_tokens = 13 + new_req_data.additional_information = None + + updated = OmniGPUModelRunner._update_streaming_request( + runner, req_id, new_req_data + ) + + assert len(updated.mm_features) == 2 + assert updated.mm_features[0] == mm_feature_1 + assert updated.mm_features[1] == mm_feature_2 + assert len(updated.prompt_token_ids) == 20 + assert updated.output_token_ids == [] + assert updated.num_computed_tokens == 13 + assert updated.sampling_params.temperature == 0.7 + assert req_id not in runner.input_batch.req_id_to_index + + +def test_update_states_routes_to_streaming_update(mock_runner_with_input_batch): + """Test that _update_states detects existing request and routes to + _update_streaming_request instead of creating a new CachedRequestState. 
+ """ + runner = mock_runner_with_input_batch + req_id = "route_test_req" + + initial_state = CachedRequestState( + req_id=req_id, + prompt_token_ids=[1, 2, 3], + mm_features=[], + sampling_params=SamplingParams(), + pooling_params=None, + generator=None, + block_ids=([0],), + num_computed_tokens=3, + output_token_ids=[10], + ) + runner.requests[req_id] = initial_state + + # The key behavior: if req_id is already in runner.requests when + # processing scheduled_new_reqs, it should call _update_streaming_request + assert req_id in runner.requests + assert initial_state.output_token_ids == [10] + + new_req_data = Mock() + new_req_data.prompt_token_ids = [1, 2, 3, 10, 4] + new_req_data.mm_features = [] + new_req_data.prompt_embeds = None + new_req_data.sampling_params = SamplingParams() + new_req_data.pooling_params = None + new_req_data.block_ids = ([0, 1],) + new_req_data.num_computed_tokens = 4 + new_req_data.additional_information = None + + updated = OmniGPUModelRunner._update_streaming_request( + runner, req_id, new_req_data + ) + + # Same object should be reused + assert updated is initial_state + assert updated.prompt_token_ids == [1, 2, 3, 10, 4] + assert updated.output_token_ids == [] diff --git a/vllm_omni/config/model.py b/vllm_omni/config/model.py index 4fe136288a6..c9c2c45354b 100644 --- a/vllm_omni/config/model.py +++ b/vllm_omni/config/model.py @@ -113,6 +113,7 @@ def __post_init__( mm_processor_cache_gb: float | None, mm_processor_cache_type: MMCacheType | None, mm_shm_cache_max_object_size_mb: int | None, + mm_encoder_only: bool | None, mm_encoder_tp_mode: MMEncoderTPMode | None, mm_encoder_attn_backend: AttentionBackendEnum | str | None, interleave_mm_strings: bool | None, @@ -265,6 +266,7 @@ def __post_init__( mm_processor_cache_gb=mm_processor_cache_gb, mm_processor_cache_type=mm_processor_cache_type, mm_shm_cache_max_object_size_mb=mm_shm_cache_max_object_size_mb, + mm_encoder_only=mm_encoder_only, mm_encoder_tp_mode=mm_encoder_tp_mode, 
mm_encoder_attn_backend=mm_encoder_attn_backend, interleave_mm_strings=interleave_mm_strings, diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index 03eb5ff212f..4a01759ab8a 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -295,6 +295,7 @@ def update_from_output( if request.output_token_ids: stopped = check_stop(request, self.max_model_len) routed_experts = None + finish_reason = None if stopped: # [Omni] Handle routed experts if enabled if self.vllm_config.model_config.enable_return_routed_experts: @@ -317,7 +318,16 @@ def update_from_output( routed_experts = self.routed_experts_reader.get_routed_experts(indices=slot_mapping) - kv_transfer_params = self._free_request(request) + # Capture finish_reason BEFORE _handle_stopped_request, which may + # reset the status to WAITING for streaming requests that continue. + finish_reason = request.get_finished_reason() + + # Handle resumable/streaming input requests - only actually finish + # if _handle_stopped_request returns True + finished = self._handle_stopped_request(request) + if finished: + kv_transfer_params = self._free_request(request) + if status_before_stop == RequestStatus.RUNNING: stopped_running_reqs.add(request) else: @@ -344,13 +354,13 @@ def update_from_output( # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) - if new_token_ids or pooler_output is not None or kv_transfer_params: + if new_token_ids or pooler_output is not None or kv_transfer_params or stopped: # Add EngineCoreOutput for this Request. 
outputs[request.client_index].append( EngineCoreOutput( request_id=req_id, new_token_ids=new_token_ids, - finish_reason=request.get_finished_reason(), + finish_reason=finish_reason, new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, pooling_output=pooler_output, diff --git a/vllm_omni/engine/input_processor.py b/vllm_omni/engine/input_processor.py index eb81f38dc66..0bac59f201c 100644 --- a/vllm_omni/engine/input_processor.py +++ b/vllm_omni/engine/input_processor.py @@ -75,17 +75,21 @@ def _dtype_to_name(dtype: torch.dtype) -> str: def __init__( self, vllm_config: VllmConfig, - tokenizer: TokenizerLike, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): - super().__init__(vllm_config, tokenizer, mm_registry) + super().__init__(vllm_config, mm_registry) self.input_preprocessor = OmniInputPreprocessor( self.model_config, - self.tokenizer, + vllm_config.observability_config, mm_registry, mm_processor_cache=self.mm_processor_cache, ) + @property + def tokenizer(self) -> TokenizerLike | None: + """Get tokenizer from input_preprocessor.""" + return self.input_preprocessor.tokenizer + def process_inputs( self, request_id: str, @@ -97,6 +101,7 @@ def process_inputs( trace_headers: Mapping[str, str] | None = None, priority: int = 0, data_parallel_rank: int | None = None, + resumable: bool = False, ) -> tuple[str | None, OmniEngineCoreRequest]: """Process input prompt into an engine core request. 
@@ -281,4 +286,5 @@ def process_inputs( trace_headers=trace_headers, prompt_embeds=prompt_embeds_payload, additional_information=additional_information_payload, + resumable=resumable, ) diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py index ab5cf878c5a..917a2fdcaa3 100644 --- a/vllm_omni/engine/output_processor.py +++ b/vllm_omni/engine/output_processor.py @@ -131,79 +131,20 @@ def _consolidate_multimodal_tensors(self) -> None: except Exception: logger.exception("Error consolidating multimodal tensors") - # Override: do not route to pooling-only path; always create completion - # outputs, and attach pooling_result into the CompletionOutput. - def make_request_output( - self, - new_token_ids: list[int], - pooling_output: torch.Tensor | None, - finish_reason: FinishReason | None, - stop_reason: int | str | None, - kv_transfer_params: dict[str, Any] | None = None, - ) -> OmniRequestOutput | PoolingRequestOutput | None: - """Create a request output from generation results. - - Creates a RequestOutput or PoolingRequestOutput from the generated - tokens and accumulated multimodal outputs. Attaches multimodal - tensors to the completion output if available. - - Args: - new_token_ids: List of newly generated token IDs - pooling_output: Optional pooling output tensor - finish_reason: Optional finish reason indicating why generation stopped - stop_reason: Optional stop reason (token ID or stop string) - kv_transfer_params: Optional KV cache transfer parameters - - Returns: - OmniRequestOutput or PoolingRequestOutput if output should be - emitted (based on finish status and output kind), None otherwise - """ - finished = finish_reason is not None - final_only = self.output_kind == RequestOutputKind.FINAL_ONLY - - if not finished and final_only: - return None - - if self.stream_interval > 1: - assert self.detokenizer is not None - - # Send output request only when - # 1. It has finished, or - # 2. It is the first token, or - # 3. 
It has reached the stream interval number of tokens - if not ( - finished - or self.sent_tokens_offset == 0 - or len(self.detokenizer.output_token_ids) - self.sent_tokens_offset >= self.stream_interval - ): - return None - - if self.output_kind == RequestOutputKind.DELTA: - # Send tokens from the offset in DELTA mode, otherwise all - # tokens are sent. - new_token_ids = self.detokenizer.output_token_ids[self.sent_tokens_offset :] - self.sent_tokens_offset = len(self.detokenizer.output_token_ids) - - request_id = self.request_id - output = self._new_completion_output(new_token_ids, finish_reason, stop_reason) - - if self.parent_req is None: - outputs = [output] - else: - request_id, outputs, finished = self.parent_req.get_outputs(request_id, output) - if not outputs: - return None - - return self._new_request_output(request_id, outputs, finished, kv_transfer_params) + # Note: make_request_output is inherited from base RequestState. + # The multimodal output attachment is done in _new_completion_output below. def _new_completion_output( self, token_ids: list[int], finish_reason: FinishReason | None, stop_reason: int | str | None, + routed_experts: Any = None, ) -> Any: - # Reuse base text/logprobs logic, then annotate with pooling_result. - base_output = super()._new_completion_output(token_ids, finish_reason, stop_reason) + # Reuse base text/logprobs logic, then annotate with multimodal output. 
+ base_output = super()._new_completion_output( + token_ids, finish_reason, stop_reason, routed_experts + ) try: if self.mm_accumulated is not None: # Attach accumulated multimodal dict on the completion output @@ -286,13 +227,13 @@ def add_request( parent_req: Optional parent request for parallel sampling request_index: Index of the request in the batch queue: Optional queue for collecting outputs - - Raises: - ValueError: If the request ID is already registered """ request_id = request.request_id - if request_id in self.request_states: - raise ValueError(f"Request id {request_id} already running.") + req_state = self.request_states.get(request_id) + if req_state is not None: + # Streaming input session update - use base class method + self._update_streaming_request_state(req_state, request, prompt) + return req_state = OmniRequestState.from_new_request( tokenizer=self.tokenizer, @@ -304,9 +245,22 @@ def add_request( log_stats=self.log_stats, stream_interval=self.stream_interval, ) + if self._requests_drained.is_set(): + self._requests_drained.clear() self.request_states[request_id] = req_state if parent_req: self.parent_requests[parent_req.request_id] = parent_req + # Track the external_req_id -> [internal_req_id, ...] 
mapping + self.external_req_ids[req_state.external_req_id].append(request_id) + + def _finish_request(self, req_state: RequestState) -> None: + """Clean up a finished request with omni-specific cleanup.""" + # Cleanup per-request mm state before calling base + if isinstance(req_state, OmniRequestState): + req_state.mm_accumulated = None + req_state.mm_type = None + # Call base class implementation + super()._finish_request(req_state) def process_outputs( self, @@ -395,6 +349,10 @@ def process_outputs( kv_transfer_params, ) if ro: + # For streaming input, mark output as not finished while streaming + if isinstance(req_state, OmniRequestState) and req_state.streaming_input: + ro.finished = False + # Attach accumulated multimodal payload if any try: if isinstance(req_state, OmniRequestState) and req_state.mm_accumulated is not None: @@ -410,21 +368,24 @@ def process_outputs( # 4) Free completed if finish_reason is not None: - self.request_states.pop(req_id) - parent_req = req_state.parent_req - if parent_req and not parent_req.child_requests: - self.parent_requests.pop(parent_req.request_id, None) - if not self.request_states: - self._requests_drained.set() - if not eco.finished: - reqs_to_abort.append(req_id) - self._update_stats_from_finished(req_state, finish_reason, iteration_stats) - if self.tracer: - self.do_tracing(eco, req_state, iteration_stats) - # Cleanup per-request mm state - if isinstance(req_state, OmniRequestState): - req_state.mm_accumulated = None - req_state.mm_type = None + # Handle streaming input state + if isinstance(req_state, OmniRequestState) and req_state.streaming_input: + if req_state.input_chunk_queue: + update = req_state.input_chunk_queue.popleft() + req_state.apply_streaming_update(update) + else: + req_state.input_chunk_queue = None + else: + self._finish_request(req_state) + if not eco.finished: + # If req not finished in EngineCore, but Detokenizer + # detected stop string, abort needed in EngineCore. 
+ reqs_to_abort.append(req_id) + + # Track per-request stats + self._update_stats_from_finished(req_state, finish_reason, iteration_stats) + if self.tracer: + self.do_tracing(eco, req_state, iteration_stats) return OutputProcessorOutput( request_outputs=request_outputs, diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 88a290cdfce..e9b8b561652 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -192,11 +192,9 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None: if stage.vllm_config is not None and stage.tokenizer is not None: try: vllm_config = stage.vllm_config - tokenizer = stage.tokenizer - # Initialize input_processor + # Initialize input_processor (tokenizer is obtained internally via renderer) self.input_processor = OmniInputProcessor( vllm_config=vllm_config, - tokenizer=tokenizer, ) # Initialize model_config self.model_config = vllm_config.model_config diff --git a/vllm_omni/entrypoints/async_omni_llm.py b/vllm_omni/entrypoints/async_omni_llm.py index 287f12b9ed7..57ca9c4a63c 100644 --- a/vllm_omni/entrypoints/async_omni_llm.py +++ b/vllm_omni/entrypoints/async_omni_llm.py @@ -6,19 +6,21 @@ from typing import TYPE_CHECKING import torch -import vllm.envs as envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.tokenizers import cached_tokenizer_from_config from vllm.tracing import init_tracer from vllm.transformers_utils.config import maybe_register_config_serialize_by_value from vllm.usage.usage_lib import UsageContext from vllm.utils.func_utils import deprecate_kwargs from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.core_client import EngineCoreClient -from vllm.v1.executor.abstract import Executor -from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager +from vllm.v1.executor import Executor +from vllm.v1.metrics.loggers import 
( + StatLoggerFactory, + StatLoggerManager, + load_stat_logger_plugin_factories, +) from vllm_omni.engine.arg_utils import AsyncOmniEngineArgs from vllm_omni.engine.input_processor import OmniInputProcessor @@ -51,6 +53,7 @@ class AsyncOmniLLM(AsyncLLM): stat_loggers: Customized stat loggers for the engine. If not provided, default stat loggers will be used. Note: Stat logger interface may change in V1. + aggregate_engine_logging: Whether to aggregate engine logging client_addresses: Optional dictionary mapping client names to addresses client_count: Total number of clients (default: 1) client_index: Index of this client (default: 0) @@ -68,6 +71,7 @@ def __init__( log_requests: bool = True, start_engine_loop: bool = True, stat_loggers: list[StatLoggerFactory] | None = None, + aggregate_engine_logging: bool = False, client_addresses: dict[str, str] | None = None, client_count: int = 1, client_index: int = 0, @@ -76,6 +80,7 @@ def __init__( Create an AsyncOmniLLM. Args: + engine_args: AsyncOmniEngineArgs containing engine configuration vllm_config: global configuration. executor_class: an Executor impl, e.g. MultiprocExecutor. log_stats: Whether to log stats. @@ -88,6 +93,7 @@ def __init__( If not provided, default stat loggers will be used. PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE. + aggregate_engine_logging: Whether to aggregate engine logging. 
Returns: None @@ -100,29 +106,26 @@ def __init__( self.observability_config = vllm_config.observability_config self.log_requests = log_requests - self.log_stats = log_stats or (stat_loggers is not None) - if not log_stats and stat_loggers is not None: + custom_stat_loggers = list(stat_loggers or []) + custom_stat_loggers.extend(load_stat_logger_plugin_factories()) + + has_custom_loggers = bool(custom_stat_loggers) + self.log_stats = log_stats or has_custom_loggers + if not log_stats and has_custom_loggers: logger.info( - "AsyncLLM created with log_stats=False and non-empty custom logger list; " + "AsyncOmniLLM created with log_stats=False and non-empty custom logger list; " "enabling logging without default stat loggers" ) - if self.model_config.skip_tokenizer_init: - tokenizer = None - else: - # Tokenizer (+ ensure liveness if running in another process). - tokenizer = cached_tokenizer_from_config(model_config=vllm_config.model_config) - # InputProcessor (converts Inputs --> EngineCoreRequests). self.input_processor = OmniInputProcessor( vllm_config=vllm_config, - tokenizer=tokenizer, mm_registry=mm_registry, ) # OutputProcessor (converts EngineCoreOutputs --> RequestOutput). 
self.output_processor = MultimodalOutputProcessor( - tokenizer=tokenizer, + tokenizer=self.tokenizer, log_stats=self.log_stats, engine_core_output_type=engine_args.engine_output_type, ) @@ -151,9 +154,10 @@ def __init__( self.logger_manager = StatLoggerManager( vllm_config=vllm_config, engine_idxs=self.engine_core.engine_ranks_managed, - custom_stat_loggers=stat_loggers, + custom_stat_loggers=custom_stat_loggers, enable_default_loggers=log_stats, client_count=client_count, + aggregate_engine_logging=aggregate_engine_logging, ) self.logger_manager.log_engine_initialized() @@ -165,21 +169,25 @@ def __init__( except RuntimeError: pass - if envs.VLLM_TORCH_PROFILER_DIR and not envs.VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: + if ( + vllm_config.profiler_config.profiler == "torch" + and not vllm_config.profiler_config.ignore_frontend + ): + profiler_dir = vllm_config.profiler_config.torch_profiler_dir logger.info( "Torch profiler enabled. AsyncOmniLLM CPU traces will be collected under %s", - envs.VLLM_TORCH_PROFILER_DIR, + profiler_dir, ) worker_name = f"{socket.gethostname()}_{os.getpid()}.async_omni_llm" self.profiler = torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, ], - with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, + with_stack=vllm_config.profiler_config.torch_profiler_with_stack, on_trace_ready=torch.profiler.tensorboard_trace_handler( - envs.VLLM_TORCH_PROFILER_DIR, + profiler_dir, worker_name=worker_name, - use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP, + use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip, ), ) else: @@ -198,6 +206,7 @@ def from_vllm_config( usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: list[StatLoggerFactory] | None = None, enable_log_requests: bool = False, + aggregate_engine_logging: bool = False, disable_log_stats: bool = False, client_addresses: dict[str, str] | None = None, client_count: int = 1, @@ -212,6 +221,7 @@ def from_vllm_config( stat_loggers=stat_loggers, 
log_requests=enable_log_requests, log_stats=not disable_log_stats, + aggregate_engine_logging=aggregate_engine_logging, usage_context=usage_context, client_addresses=client_addresses, client_count=client_count, diff --git a/vllm_omni/entrypoints/chat_utils.py b/vllm_omni/entrypoints/chat_utils.py index 0fdef5edbb7..08181faa80c 100644 --- a/vllm_omni/entrypoints/chat_utils.py +++ b/vllm_omni/entrypoints/chat_utils.py @@ -15,7 +15,7 @@ MultiModalDataDict, MultiModalUUIDDict, _AssistantParser, - _ChatTemplateContentFormat, + ChatTemplateContentFormat, _ContentPart, _get_full_multimodal_text_prompt, _parse_chat_message_content_part, @@ -128,7 +128,7 @@ def _cleanup_file_sync(file_path: str) -> None: def parse_chat_messages_futures( messages: list[ChatCompletionMessageParam], model_config: ModelConfig, - content_format: _ChatTemplateContentFormat, + content_format: ChatTemplateContentFormat, mm_processor_kwargs: dict[str, Any] | None = None, ) -> tuple[ list[ConversationMessage], @@ -161,7 +161,7 @@ def parse_chat_messages_futures( def _parse_chat_message_content( message: ChatCompletionMessageParam, mm_tracker: BaseMultiModalItemTracker, - content_format: _ChatTemplateContentFormat, + content_format: ChatTemplateContentFormat, interleave_strings: bool, mm_processor_kwargs: dict[str, Any] | None = None, ) -> list[ConversationMessage]: diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py index 6bb8ac3663f..55b8803478d 100644 --- a/vllm_omni/entrypoints/omni_llm.py +++ b/vllm_omni/entrypoints/omni_llm.py @@ -162,7 +162,7 @@ def __init__( engine_core_output_type=engine_args.engine_output_type, ) self.llm_engine.input_processor = OmniInputProcessor( - vllm_config=self.llm_engine.vllm_config, tokenizer=self.llm_engine.tokenizer + vllm_config=self.llm_engine.vllm_config, ) self.engine_class = type(self.llm_engine) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 2369a10001e..71d99a4d796 
100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -18,7 +18,7 @@ from starlette.datastructures import State from starlette.routing import Route from vllm.engine.protocol import EngineClient -from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages +from vllm.entrypoints.anthropic.serving import AnthropicServingMessages from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.api_server import ( @@ -29,15 +29,16 @@ setup_server, ) from vllm.entrypoints.openai.orca_metrics import metrics_header -from vllm.entrypoints.openai.protocol import ( +from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionRequest, ChatCompletionResponse, - ErrorResponse, ) -from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion -from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels -from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses -from vllm.entrypoints.openai.serving_transcription import ( +from vllm.entrypoints.openai.engine.protocol import ErrorResponse +from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion +from vllm.entrypoints.openai.models.protocol import BaseModelPath +from vllm.entrypoints.openai.models.serving import OpenAIServingModels +from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses +from vllm.entrypoints.openai.translations.serving import ( OpenAIServingTranscription, OpenAIServingTranslation, ) @@ -48,10 +49,10 @@ from vllm.entrypoints.pooling.score.serving import ServingScores from vllm.entrypoints.serve.disagg.serving import ServingTokens from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization -from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer +from vllm.entrypoints.mcp.tool_server import DemoToolServer, 
MCPToolServer, ToolServer +from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.utils import ( load_aware_call, - process_chat_template, process_lora_modules, with_cancellation, ) @@ -357,11 +358,7 @@ async def omni_init_app_state( supported_tasks = set(await engine_client.get_supported_tasks()) logger.info("Supported tasks: %s", supported_tasks) - resolved_chat_template = await process_chat_template( - args.chat_template, - engine_client, - vllm_config.model_config if vllm_config is not None else None, - ) + resolved_chat_template = load_chat_template(args.chat_template) if args.tool_server == "demo": tool_server: ToolServer | None = DemoToolServer() @@ -403,7 +400,6 @@ async def omni_init_app_state( if not hasattr(engine_client, "input_processor") or engine_client.input_processor is None: engine_client.input_processor = OmniInputProcessor( vllm_config=vllm_config, - tokenizer=tokenizer, ) logger.info("Initialized input_processor for AsyncOmni") diff --git a/vllm_omni/entrypoints/openai/protocol/chat_completion.py b/vllm_omni/entrypoints/openai/protocol/chat_completion.py index d0c83f56f8b..d106b7aa7ae 100644 --- a/vllm_omni/entrypoints/openai/protocol/chat_completion.py +++ b/vllm_omni/entrypoints/openai/protocol/chat_completion.py @@ -1,4 +1,4 @@ -from vllm.entrypoints.openai.protocol import ( +from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionStreamResponse, ) diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 5d9f76d70c3..5487ef4fb03 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -24,17 +24,15 @@ ChatCompletionMessageParam, ChatTemplateContentFormatOption, ConversationMessage, - apply_hf_chat_template, - apply_mistral_chat_template, get_history_tool_calls_cnt, make_tool_call_id, - resolve_chat_template_content_format, ) +from vllm.renderers.hf import 
resolve_chat_template_content_format from vllm.entrypoints.openai.parser.harmony_utils import ( get_streamable_parser_for_assistant, parse_chat_output, ) -from vllm.entrypoints.openai.protocol import ( +from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, ChatCompletionResponse, @@ -42,21 +40,23 @@ ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, - DeltaFunctionCall, DeltaMessage, + FunctionCall, + FunctionDefinition, + ToolCall, +) +from vllm.entrypoints.openai.engine.protocol import ( + DeltaFunctionCall, DeltaToolCall, ErrorInfo, ErrorResponse, - FunctionCall, - FunctionDefinition, PromptTokenUsageInfo, RequestResponseMetadata, - ResponsesRequest, - ToolCall, UsageInfo, ) -from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.serving_engine import ( +from vllm.entrypoints.openai.responses.protocol import ResponsesRequest +from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat +from vllm.entrypoints.openai.engine.serving import ( ChatLikeRequest, clamp_prompt_logprobs, ) @@ -376,16 +376,18 @@ async def _preprocess_chat( if tokenizer is None: request_prompt = "placeholder" elif isinstance(tokenizer, MistralTokenizer): - request_prompt = apply_mistral_chat_template( + from vllm.renderers.mistral import safe_apply_chat_template as mistral_apply + request_prompt = mistral_apply( tokenizer, messages=messages, **_chat_template_kwargs, ) else: - request_prompt = apply_hf_chat_template( - tokenizer=tokenizer, + from vllm.renderers.hf import safe_apply_chat_template as hf_apply + request_prompt = hf_apply( + model_config, + tokenizer, conversation=conversation, - model_config=model_config, **_chat_template_kwargs, ) diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index c6b87810e98..77be4cc8f35 100644 --- 
a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -2,7 +2,7 @@ from fastapi import Request from fastapi.responses import Response -from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.logger import init_logger from vllm.utils import random_uuid diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py index 2d479062eb2..e5662c37c8b 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py @@ -83,12 +83,14 @@ from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems from vllm.multimodal.processing import ( - MultiModalPromptUpdates, - PlaceholderFeaturesInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails, ) +from vllm.multimodal.processing.processor import ( + MultiModalPromptUpdates, + PlaceholderFeaturesInfo, +) from vllm.sequence import IntermediateTensors from vllm_omni.model_executor.models.qwen2_5_omni.qwen2_5_omni_thinker import ( diff --git a/vllm_omni/request.py b/vllm_omni/request.py index ca5bd437da6..99d74e84bc6 100644 --- a/vllm_omni/request.py +++ b/vllm_omni/request.py @@ -68,4 +68,5 @@ def from_engine_core_request( trace_headers=request.trace_headers, block_hasher=block_hasher, additional_information=request.additional_information, + resumable=request.resumable, ) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index ecc839fd3a2..47d70332192 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -214,6 +214,13 @@ def execute_model( use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 ubatch_slices_attn = 
ubatch_slices_padded if pad_attn else ubatch_slices + slot_mappings_by_group, slot_mappings = self._get_slot_mappings( + num_tokens_padded=num_tokens_padded if pad_attn else num_tokens_unpadded, + num_reqs_padded=num_reqs_padded if pad_attn else num_reqs, + num_tokens_unpadded=num_tokens_unpadded, + ubatch_slices=ubatch_slices_padded, + ) + attn_metadata, spec_decode_common_attn_metadata = self._build_attention_metadata( num_tokens=num_tokens_unpadded, num_tokens_padded=num_tokens_padded if pad_attn else None, @@ -225,6 +232,7 @@ def execute_model( use_spec_decode=use_spec_decode, num_scheduled_tokens=scheduler_output.num_scheduled_tokens, cascade_attn_prefix_lens=cascade_attn_prefix_lens, + slot_mappings=slot_mappings_by_group, ) ( @@ -255,6 +263,7 @@ def execute_model( cudagraph_runtime_mode=cudagraph_mode, batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, + slot_mapping=slot_mappings, ), record_function_or_nullcontext("gpu_model_runner: forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index 87e1419c87c..1e3b9be4584 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -18,7 +18,6 @@ from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( RoutedExpertsCapturer, ) -from vllm.model_executor.models.interfaces import supports_mm_encoder_only from vllm.utils.math_utils import cdiv from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.outputs import AsyncModelRunnerOutput, make_empty_encoder_model_runner_output @@ -190,6 +189,13 @@ def execute_model( use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices + slot_mappings_by_group, slot_mappings = self._get_slot_mappings( + num_tokens_padded=num_tokens_padded 
if pad_attn else num_tokens_unpadded, + num_reqs_padded=num_reqs_padded if pad_attn else num_reqs, + num_tokens_unpadded=num_tokens_unpadded, + ubatch_slices=ubatch_slices_padded, + ) + attn_metadata, spec_decode_common_attn_metadata = self._build_attention_metadata( num_tokens=num_tokens_unpadded, num_tokens_padded=num_tokens_padded if pad_attn else None, @@ -201,6 +207,7 @@ def execute_model( use_spec_decode=use_spec_decode, num_scheduled_tokens=scheduler_output.num_scheduled_tokens, cascade_attn_prefix_lens=cascade_attn_prefix_lens, + slot_mappings=slot_mappings_by_group, ) ( @@ -235,6 +242,7 @@ def execute_model( cudagraph_runtime_mode=cudagraph_mode, batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, + slot_mapping=slot_mappings, ), record_function_or_nullcontext("Forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, @@ -435,7 +443,9 @@ def _dummy_run( remove_lora: If False, dummy LoRAs are not destroyed after the run activate_lora: If False, dummy_run is performed without LoRAs. """ - if supports_mm_encoder_only(self.model): + # Check if mm_encoder_only mode is enabled (configuration-based check) + mm_config = getattr(self.model_config, 'multimodal_config', None) + if mm_config is not None and getattr(mm_config, 'mm_encoder_only', False): # The current dummy run only covers LM execution, so we can skip it. # mm encoder dummy run may need to add in the future. return torch.tensor([]), torch.tensor([]) @@ -539,6 +549,13 @@ def _dummy_run( attn_metadata: PerLayerAttnMetadata | None = None + slot_mappings_by_group, slot_mappings = self._get_slot_mappings( + num_tokens_padded=num_tokens, + num_reqs_padded=num_reqs_padded, + num_tokens_unpadded=num_tokens_unpadded, + ubatch_slices=ubatch_slices_padded, + ) + # If force_attention is True, we always capture attention. Otherwise, # it only happens for cudagraph_runtime_mode=FULL. 
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: @@ -564,6 +581,7 @@ def _dummy_run( max_query_len=max_query_len, ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices, for_cudagraph_capture=is_graph_capturing, + slot_mappings=slot_mappings_by_group, ) with self.maybe_dummy_run_with_lora( @@ -628,6 +646,7 @@ def _dummy_run( cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, + slot_mapping=slot_mappings, ), ): outputs = self.model( @@ -664,6 +683,7 @@ def _dummy_run( num_tokens, use_cudagraphs=use_cudagraphs, is_graph_capturing=is_graph_capturing, + slot_mappings=slot_mappings, ) # We register layerwise NVTX hooks here after the first dynamo tracing is diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index a486212aa6c..f83186cfcd7 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -8,7 +8,7 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.models.interfaces import supports_mm_encoder_only, supports_mrope +from vllm.model_executor.models.interfaces import supports_mrope from vllm.model_executor.models.interfaces_base import VllmModelForPooling from vllm.sampling_params import SamplingType from vllm.utils.import_utils import LazyLoader @@ -170,6 +170,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: req_id = new_req_data.req_id + if req_id in self.requests: + # For streaming case only. 
+ req_state = self._update_streaming_request(req_id, new_req_data) + reqs_to_add.append(req_state) + continue + sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params @@ -447,7 +453,9 @@ def _dummy_run( remove_lora: If False, dummy LoRAs are not destroyed after the run activate_lora: If False, dummy_run is performed without LoRAs. """ - if supports_mm_encoder_only(self.model): + # Check if mm_encoder_only mode is enabled (configuration-based check) + mm_config = getattr(self.model_config, 'multimodal_config', None) + if mm_config is not None and getattr(mm_config, 'mm_encoder_only', False): # The current dummy run only covers LM execution, so we can skip it. # mm encoder dummy run may need to add in the future. return torch.tensor([]), torch.tensor([]) @@ -551,6 +559,13 @@ def _dummy_run( attn_metadata: PerLayerAttnMetadata | None = None + slot_mappings_by_group, slot_mappings = self._get_slot_mappings( + num_tokens_padded=num_tokens, + num_reqs_padded=num_reqs_padded, + num_tokens_unpadded=num_tokens_unpadded, + ubatch_slices=ubatch_slices_padded, + ) + # If force_attention is True, we always capture attention. Otherwise, # it only happens for cudagraph_runtime_mode=FULL. 
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: @@ -576,6 +591,7 @@ def _dummy_run( max_query_len=max_query_len, ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices, for_cudagraph_capture=is_graph_capturing, + slot_mappings=slot_mappings_by_group, ) with self.maybe_dummy_run_with_lora( @@ -640,6 +656,7 @@ def _dummy_run( cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, + slot_mapping=slot_mappings, ), ): if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"): @@ -683,6 +700,7 @@ def _dummy_run( num_tokens, use_cudagraphs=use_cudagraphs, is_graph_capturing=is_graph_capturing, + slot_mappings=slot_mappings, ) # We register layerwise NVTX hooks here after the first dynamo tracing is diff --git a/vllm_omni/worker/npu/npu_generation_model_runner.py b/vllm_omni/worker/npu/npu_generation_model_runner.py index e6641ca124a..39c90e30a8b 100644 --- a/vllm_omni/worker/npu/npu_generation_model_runner.py +++ b/vllm_omni/worker/npu/npu_generation_model_runner.py @@ -14,7 +14,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.logger import logger -from vllm.multimodal.inputs import MultiModalKwargs +from vllm.multimodal.inputs import MultiModalKwargsItem as MultiModalKwargs from vllm.sequence import IntermediateTensors from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder