From f6ec96f30a023b29af649d14e06f4270e5605eef Mon Sep 17 00:00:00 2001 From: princepride Date: Sun, 5 Apr 2026 12:13:11 +0000 Subject: [PATCH 1/3] [BugFix] Continue decode if don't need transfer kv cache between two stages Signed-off-by: princepride --- vllm_omni/core/sched/omni_ar_scheduler.py | 19 +++++++++++- vllm_omni/engine/async_omni_engine.py | 38 ++++++++++++++++++++++- 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index 0956d1856a4..a2d28bf1709 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -82,6 +82,14 @@ def _get_kv_transfer_criteria(self) -> dict | None: return getattr(omni_kv_config, "kv_transfer_criteria", None) return None + def _request_omits_kv_transfer_to_next_stage(self, request: Request) -> bool: + """True when orchestrator will not run stage 1+ for this request (e.g. text-only).""" + payload = getattr(request, "additional_information", None) + if payload is None: + return False + info = deserialize_additional_information(payload) + return info.get("omni_final_stage_id") == 0 + def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int]) -> bool: """ Check triggers and process side effects (marking transfer). @@ -91,6 +99,10 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int if not self.kv_transfer_criteria: return False + # Text-only requests finalize at stage 0; do not prefill-stop for DiT KV. + if self._request_omits_kv_transfer_to_next_stage(request): + return False + if request.request_id in self.waiting_for_transfer_free: return False @@ -638,7 +650,12 @@ def _should_transfer_kv_for_request(self, req_id: str) -> bool: need_send = omni_kv_config.get("need_send_cache", False) else: need_send = getattr(omni_kv_config, "need_send_cache", False) - return need_send + if not need_send: + return False + request = self.requests.get(req_id) + if request is not None and self._request_omits_kv_transfer_to_next_stage(request): + return False + return True def has_requests(self) -> bool: """Check if there are any requests to process, including KV transfers.""" diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 092b341e42a..6b410fb57c3 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -42,7 +42,10 @@ ) from vllm_omni.engine.orchestrator import Orchestrator from vllm_omni.engine.output_processor import MultimodalOutputProcessor -from vllm_omni.engine.serialization import serialize_additional_information +from vllm_omni.engine.serialization import ( + deserialize_additional_information, + serialize_additional_information, +) from vllm_omni.engine.stage_engine_core_client import StageEngineCoreClient from vllm_omni.engine.stage_engine_core_proc import ( complete_stage_handshake, @@ -170,6 +173,38 @@ def _upgrade_to_omni_request( ) +def _apply_omni_final_stage_metadata( + request: EngineCoreRequest, + final_stage_id: int, +) -> EngineCoreRequest: + """Tag EngineCoreRequest so OmniARScheduler can skip DiT KV when final_stage_id is 0.""" + merged: dict[str, Any] = {} + if isinstance(request, OmniEngineCoreRequest) and request.additional_information is not None: + merged = deserialize_additional_information(request.additional_information) + merged["omni_final_stage_id"] = final_stage_id + payload = serialize_additional_information(merged) + return OmniEngineCoreRequest( + request_id=request.request_id, + prompt_token_ids=request.prompt_token_ids, + mm_features=request.mm_features, + sampling_params=request.sampling_params, + pooling_params=request.pooling_params, + arrival_time=request.arrival_time, + lora_request=request.lora_request, + cache_salt=request.cache_salt, + data_parallel_rank=request.data_parallel_rank, + prompt_embeds=request.prompt_embeds, + client_index=request.client_index, + current_wave=request.current_wave, + priority=request.priority, + trace_headers=request.trace_headers, + resumable=request.resumable, + external_req_id=request.external_req_id, + reasoning_ended=request.reasoning_ended, + additional_information=payload, + ) + + def _weak_shutdown_async_omni_engine( orchestrator_thread: threading.Thread | None, request_queue: janus.Queue[dict[str, Any]] | None, @@ -713,6 +748,7 @@ def _build_add_request_message( # to match the key used in Orchestrator.request_states so that # output routing (output.request_id lookup) can find the req_state. request.external_req_id = request_id + request = _apply_omni_final_stage_metadata(request, final_stage_id) # Register with stage 0's output processor. output_prompt_text = prompt_text From 24b6e2d385b455369ca7ad8a25c9f209f5bccedc Mon Sep 17 00:00:00 2001 From: princepride Date: Sun, 5 Apr 2026 15:34:10 +0000 Subject: [PATCH 2/3] add text2text and img2text task Signed-off-by: princepride --- .../test_bagel_understanding.py | 144 ++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 tests/e2e/offline_inference/test_bagel_understanding.py diff --git a/tests/e2e/offline_inference/test_bagel_understanding.py b/tests/e2e/offline_inference/test_bagel_understanding.py new file mode 100644 index 00000000000..6f95e7ee00f --- /dev/null +++ b/tests/e2e/offline_inference/test_bagel_understanding.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +End-to-end tests for Bagel text2text and img2text (understanding) tasks. + +These tests validate that the Bagel multistage pipeline correctly generates +text output for understanding tasks, matching reference results. + +Equivalent to running: + python3 examples/offline_inference/bagel/end2end.py \ + --modality text2text \ + --prompts "Where is the capital of France?" + + python3 examples/offline_inference/bagel/end2end.py \ + --modality img2text \ + --prompts "Please describe this image" \ + --image-path 2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg +""" + +import os + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" +from pathlib import Path + +import pytest +from vllm.assets.image import ImageAsset + +from tests.conftest import modify_stage_config +from tests.utils import hardware_test +from vllm_omni.entrypoints.omni import Omni + +MODEL_NAME = "ByteDance-Seed/BAGEL-7B-MoT" +STAGE_CONFIG = str(Path(__file__).parent / "stage_configs" / "bagel_sharedmemory_ci.yaml") + +REFERENCE_TEXT_TEXT2TEXT = "The capital of France is Paris." + +REFERENCE_TEXT_IMG2TEXT = ( + "This is a photo of a wooden boardwalk or pathway that leads through " + "tall green grass. The path appears to be in a natural setting, possibly " + "a wetland or marsh area. The sky above is blue with some scattered " + "clouds, suggesting it might be a sunny day. The overall scene looks " + "peaceful and serene." +) + + +def _resolve_stage_config(config_path: str, run_level: str) -> str: + """Strip load_format: dummy for advanced_model (real weights).""" + if run_level == "advanced_model": + return modify_stage_config( + config_path, + deletes={ + "stage_args": { + 0: ["engine_args.load_format"], + 1: ["engine_args.load_format"], + } + }, + ) + return config_path + + +def _extract_text(omni_outputs: list) -> str: + """Extract generated text from OmniRequestOutput list.""" + for req_output in omni_outputs: + ro = getattr(req_output, "request_output", None) + if ro and getattr(ro, "outputs", None): + return "".join(getattr(o, "text", "") or "" for o in ro.outputs) + return "" + + +@pytest.mark.core_model +@pytest.mark.advanced_model +@pytest.mark.diffusion +@hardware_test(res={"cuda": "H100", "rocm": "MI325"}) +def test_bagel_text2text(run_level): + """Test Bagel text2text produces correct text output.""" + config_path = _resolve_stage_config(STAGE_CONFIG, run_level) + omni = Omni( + model=MODEL_NAME, + stage_configs_path=config_path, + stage_init_timeout=300, + ) + + try: + prompt = "<|im_start|>user\nWhere is the capital of France?<|im_end|>\n<|im_start|>assistant\n" + params_list = omni.default_sampling_params_list + omni_outputs = list( + omni.generate( + prompts=[{"prompt": prompt, "modalities": ["text"]}], + sampling_params_list=params_list, + ) + ) + + assert len(omni_outputs) > 0, "No outputs returned" + text = _extract_text(omni_outputs) + assert len(text) > 0, "Generated text is empty" + + if run_level == "advanced_model": + assert text == REFERENCE_TEXT_TEXT2TEXT, ( + f"Text mismatch: expected {REFERENCE_TEXT_TEXT2TEXT!r}, got {text!r}" + ) + finally: + omni.close() + + +@pytest.mark.core_model +@pytest.mark.advanced_model +@pytest.mark.diffusion +@hardware_test(res={"cuda": "H100", "rocm": "MI325"}) +def test_bagel_img2text(run_level): + """Test Bagel img2text produces correct text output.""" + input_image = ImageAsset("2560px-Gfp-wisconsin-madison-the-nature-boardwalk").pil_image.convert("RGB") + config_path = _resolve_stage_config(STAGE_CONFIG, run_level) + omni = Omni( + model=MODEL_NAME, + stage_configs_path=config_path, + stage_init_timeout=300, + ) + + try: + prompt = "<|im_start|>user\n<|image_pad|>\nPlease describe this image<|im_end|>\n<|im_start|>assistant\n" + params_list = omni.default_sampling_params_list + omni_outputs = list( + omni.generate( + prompts=[ + { + "prompt": prompt, + "multi_modal_data": {"image": input_image}, + "modalities": ["text"], + } + ], + sampling_params_list=params_list, + ) + ) + + assert len(omni_outputs) > 0, "No outputs returned" + text = _extract_text(omni_outputs) + assert len(text) > 0, "Generated text is empty" + + if run_level == "advanced_model": + assert text == REFERENCE_TEXT_IMG2TEXT, f"Text mismatch: expected {REFERENCE_TEXT_IMG2TEXT!r}, got {text!r}" + finally: + omni.close() From cdca252ab3151ec41f6b79b5e83bc879df4956a3 Mon Sep 17 00:00:00 2001 From: princepride Date: Sun, 5 Apr 2026 15:43:06 +0000 Subject: [PATCH 3/3] Cache _request_omits_kv_transfer result to avoid repeated deserialization deserialize_additional_information() reconstructs all entries (including tensors from bytes) on every call. Since _request_omits_kv_transfer_to_next_stage is invoked on each scheduler tick, this caused unnecessary CPU copies and memory churn during decode. Cache the boolean per request and clean up in _free_request. Signed-off-by: princepride Made-with: Cursor --- vllm_omni/core/sched/omni_ar_scheduler.py | 26 +++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index a2d28bf1709..eac737b6e66 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -64,6 +64,9 @@ def __init__(self, *args, **kwargs): # Track requests that have already triggered prefill transfer to avoid duplicates self.transfer_triggered_requests: set[str] = set() + + # Cache per-request flag to avoid repeated deserialization of additional_information + self._omits_kv_transfer_cache: dict[str, bool] = {} model_config = self.vllm_config.model_config self.chunk_transfer_adapter = None if getattr(model_config, "async_chunk", False): @@ -83,12 +86,25 @@ def _get_kv_transfer_criteria(self) -> dict | None: return None def _request_omits_kv_transfer_to_next_stage(self, request: Request) -> bool: - """True when orchestrator will not run stage 1+ for this request (e.g. text-only).""" + """True when orchestrator will not run stage 1+ for this request (e.g. text-only). + + The result is cached per request to avoid repeated deserialization of + additional_information on every scheduler tick. + """ + rid = request.request_id + cached = self._omits_kv_transfer_cache.get(rid) + if cached is not None: + return cached + payload = getattr(request, "additional_information", None) if payload is None: - return False - info = deserialize_additional_information(payload) - return info.get("omni_final_stage_id") == 0 + result = False + else: + info = deserialize_additional_information(payload) + result = info.get("omni_final_stage_id") == 0 + + self._omits_kv_transfer_cache[rid] = result + return result def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int]) -> bool: """ @@ -524,6 +540,8 @@ def _free_request(self, request: Request, delay_free_blocks: bool = False) -> di """Mark a request as finished and free its resources.""" assert request.is_finished() + self._omits_kv_transfer_cache.pop(request.request_id, None) + # 1. Standard cleanup parts from base _free_request connector_delay_free_blocks, kv_xfer_params = self._connector_finished(request)