Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions tests/e2e/offline_inference/test_bagel_understanding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
End-to-end tests for Bagel text2text and img2text (understanding) tasks.

These tests validate that the Bagel multistage pipeline correctly generates
text output for understanding tasks, matching reference results.

Equivalent to running:
python3 examples/offline_inference/bagel/end2end.py \
--modality text2text \
--prompts "Where is the capital of France?"

python3 examples/offline_inference/bagel/end2end.py \
--modality img2text \
--prompts "Please describe this image" \
--image-path 2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg
"""

import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
from pathlib import Path

import pytest
from vllm.assets.image import ImageAsset

from tests.conftest import modify_stage_config
from tests.utils import hardware_test
from vllm_omni.entrypoints.omni import Omni

MODEL_NAME = "ByteDance-Seed/BAGEL-7B-MoT"
STAGE_CONFIG = str(Path(__file__).parent / "stage_configs" / "bagel_sharedmemory_ci.yaml")

REFERENCE_TEXT_TEXT2TEXT = "The capital of France is Paris."

REFERENCE_TEXT_IMG2TEXT = (
"This is a photo of a wooden boardwalk or pathway that leads through "
"tall green grass. The path appears to be in a natural setting, possibly "
"a wetland or marsh area. The sky above is blue with some scattered "
"clouds, suggesting it might be a sunny day. The overall scene looks "
"peaceful and serene."
)


def _resolve_stage_config(config_path: str, run_level: str) -> str:
"""Strip load_format: dummy for advanced_model (real weights)."""
if run_level == "advanced_model":
return modify_stage_config(
config_path,
deletes={
"stage_args": {
0: ["engine_args.load_format"],
1: ["engine_args.load_format"],
}
},
)
return config_path


def _extract_text(omni_outputs: list) -> str:
"""Extract generated text from OmniRequestOutput list."""
for req_output in omni_outputs:
ro = getattr(req_output, "request_output", None)
if ro and getattr(ro, "outputs", None):
return "".join(getattr(o, "text", "") or "" for o in ro.outputs)
return ""


@pytest.mark.core_model
@pytest.mark.advanced_model
@pytest.mark.diffusion
@hardware_test(res={"cuda": "H100", "rocm": "MI325"})
def test_bagel_text2text(run_level):
"""Test Bagel text2text produces correct text output."""
config_path = _resolve_stage_config(STAGE_CONFIG, run_level)
omni = Omni(
model=MODEL_NAME,
stage_configs_path=config_path,
stage_init_timeout=300,
)

try:
prompt = "<|im_start|>user\nWhere is the capital of France?<|im_end|>\n<|im_start|>assistant\n"
params_list = omni.default_sampling_params_list
omni_outputs = list(
omni.generate(
prompts=[{"prompt": prompt, "modalities": ["text"]}],
sampling_params_list=params_list,
)
)

assert len(omni_outputs) > 0, "No outputs returned"
text = _extract_text(omni_outputs)
assert len(text) > 0, "Generated text is empty"

if run_level == "advanced_model":
assert text == REFERENCE_TEXT_TEXT2TEXT, (
f"Text mismatch: expected {REFERENCE_TEXT_TEXT2TEXT!r}, got {text!r}"
)
finally:
omni.close()


@pytest.mark.core_model
@pytest.mark.advanced_model
@pytest.mark.diffusion
@hardware_test(res={"cuda": "H100", "rocm": "MI325"})
def test_bagel_img2text(run_level):
"""Test Bagel img2text produces correct text output."""
input_image = ImageAsset("2560px-Gfp-wisconsin-madison-the-nature-boardwalk").pil_image.convert("RGB")
config_path = _resolve_stage_config(STAGE_CONFIG, run_level)
omni = Omni(
model=MODEL_NAME,
stage_configs_path=config_path,
stage_init_timeout=300,
)

try:
prompt = "<|im_start|>user\n<|image_pad|>\nPlease describe this image<|im_end|>\n<|im_start|>assistant\n"
params_list = omni.default_sampling_params_list
omni_outputs = list(
omni.generate(
prompts=[
{
"prompt": prompt,
"multi_modal_data": {"image": input_image},
"modalities": ["text"],
}
],
sampling_params_list=params_list,
)
)

assert len(omni_outputs) > 0, "No outputs returned"
text = _extract_text(omni_outputs)
assert len(text) > 0, "Generated text is empty"

if run_level == "advanced_model":
assert text == REFERENCE_TEXT_IMG2TEXT, f"Text mismatch: expected {REFERENCE_TEXT_IMG2TEXT!r}, got {text!r}"
finally:
omni.close()
37 changes: 36 additions & 1 deletion vllm_omni/core/sched/omni_ar_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def __init__(self, *args, **kwargs):

# Track requests that have already triggered prefill transfer to avoid duplicates
self.transfer_triggered_requests: set[str] = set()

# Cache per-request flag to avoid repeated deserialization of additional_information
self._omits_kv_transfer_cache: dict[str, bool] = {}
model_config = self.vllm_config.model_config
self.chunk_transfer_adapter = None
if getattr(model_config, "async_chunk", False):
Expand All @@ -82,6 +85,27 @@ def _get_kv_transfer_criteria(self) -> dict | None:
return getattr(omni_kv_config, "kv_transfer_criteria", None)
return None

def _request_omits_kv_transfer_to_next_stage(self, request: Request) -> bool:
"""True when orchestrator will not run stage 1+ for this request (e.g. text-only).

The result is cached per request to avoid repeated deserialization of
additional_information on every scheduler tick.
"""
rid = request.request_id
cached = self._omits_kv_transfer_cache.get(rid)
if cached is not None:
return cached

payload = getattr(request, "additional_information", None)
if payload is None:
result = False
else:
info = deserialize_additional_information(payload)
result = info.get("omni_final_stage_id") == 0

self._omits_kv_transfer_cache[rid] = result
return result

def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int]) -> bool:
"""
Check triggers and process side effects (marking transfer).
Expand All @@ -91,6 +115,10 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int
if not self.kv_transfer_criteria:
return False

# Text-only requests finalize at stage 0; do not prefill-stop for DiT KV.
if self._request_omits_kv_transfer_to_next_stage(request):
return False

if request.request_id in self.waiting_for_transfer_free:
return False

Expand Down Expand Up @@ -512,6 +540,8 @@ def _free_request(self, request: Request, delay_free_blocks: bool = False) -> di
"""Mark a request as finished and free its resources."""
assert request.is_finished()

self._omits_kv_transfer_cache.pop(request.request_id, None)

# 1. Standard cleanup parts from base _free_request
connector_delay_free_blocks, kv_xfer_params = self._connector_finished(request)

Expand Down Expand Up @@ -638,7 +668,12 @@ def _should_transfer_kv_for_request(self, req_id: str) -> bool:
need_send = omni_kv_config.get("need_send_cache", False)
else:
need_send = getattr(omni_kv_config, "need_send_cache", False)
return need_send
if not need_send:
return False
request = self.requests.get(req_id)
if request is not None and self._request_omits_kv_transfer_to_next_stage(request):
return False
return True

def has_requests(self) -> bool:
"""Check if there are any requests to process, including KV transfers."""
Expand Down
38 changes: 37 additions & 1 deletion vllm_omni/engine/async_omni_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@
)
from vllm_omni.engine.orchestrator import Orchestrator
from vllm_omni.engine.output_processor import MultimodalOutputProcessor
from vllm_omni.engine.serialization import serialize_additional_information
from vllm_omni.engine.serialization import (
deserialize_additional_information,
serialize_additional_information,
)
from vllm_omni.engine.stage_engine_core_client import StageEngineCoreClient
from vllm_omni.engine.stage_engine_core_proc import (
complete_stage_handshake,
Expand Down Expand Up @@ -170,6 +173,38 @@ def _upgrade_to_omni_request(
)


def _apply_omni_final_stage_metadata(
request: EngineCoreRequest,
final_stage_id: int,
) -> EngineCoreRequest:
"""Tag EngineCoreRequest so OmniARScheduler can skip DiT KV when final_stage_id is 0."""
merged: dict[str, Any] = {}
if isinstance(request, OmniEngineCoreRequest) and request.additional_information is not None:
merged = deserialize_additional_information(request.additional_information)
merged["omni_final_stage_id"] = final_stage_id
payload = serialize_additional_information(merged)
return OmniEngineCoreRequest(
request_id=request.request_id,
prompt_token_ids=request.prompt_token_ids,
mm_features=request.mm_features,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
arrival_time=request.arrival_time,
lora_request=request.lora_request,
cache_salt=request.cache_salt,
data_parallel_rank=request.data_parallel_rank,
prompt_embeds=request.prompt_embeds,
client_index=request.client_index,
current_wave=request.current_wave,
priority=request.priority,
trace_headers=request.trace_headers,
resumable=request.resumable,
external_req_id=request.external_req_id,
reasoning_ended=request.reasoning_ended,
additional_information=payload,
)


def _weak_shutdown_async_omni_engine(
orchestrator_thread: threading.Thread | None,
request_queue: janus.Queue[dict[str, Any]] | None,
Expand Down Expand Up @@ -713,6 +748,7 @@ def _build_add_request_message(
# to match the key used in Orchestrator.request_states so that
# output routing (output.request_id lookup) can find the req_state.
request.external_req_id = request_id
request = _apply_omni_final_stage_metadata(request, final_stage_id)

# Register with stage 0's output processor.
output_prompt_text = prompt_text
Expand Down
Loading