From b4f9e9e3e9bbbc167259bd4aaf4e17c11c5a9781 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Mon, 20 Apr 2026 02:25:06 +0000 Subject: [PATCH 1/2] [Fix] harden image edit input validation for HTTP 400 Signed-off-by: david6666666 <530634352@qq.com> --- tests/comfyui/test_comfyui_integration.py | 1 - .../openai_api/test_image_server.py | 42 ++++++++++++++++ vllm_omni/entrypoints/openai/api_server.py | 48 ++++++++++++++++++- 3 files changed, 88 insertions(+), 3 deletions(-) diff --git a/tests/comfyui/test_comfyui_integration.py b/tests/comfyui/test_comfyui_integration.py index 5164f3b9acb..80e86d82412 100644 --- a/tests/comfyui/test_comfyui_integration.py +++ b/tests/comfyui/test_comfyui_integration.py @@ -523,7 +523,6 @@ def run_server(): "Qwen/Qwen-Image-Edit", True, id="image-to-image-dalle-endpoint", - marks=pytest.mark.skip(reason="Temporarily disabled due to failure."), ), pytest.param( ServerCase( diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index 607b3eaa813..a2b0a0e9422 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -1038,6 +1038,48 @@ def _fail_load(*args, **kwargs): assert engine.captured_prompt is None +def test_image_edit_ignores_non_concrete_diffusion_input_limits(async_omni_test_client): + engine = async_omni_test_client.app.state.engine_client + engine.get_diffusion_od_config = lambda: SimpleNamespace( + supports_multimodal_inputs=SimpleNamespace(), + max_multimodal_image_inputs=SimpleNamespace(), + ) + + response = async_omni_test_client.post( + "/v1/images/edits", + files=[("image", make_test_image_bytes((16, 16)))], + data={"prompt": "hello world."}, + ) + + assert response.status_code == 200 + + +def test_image_edit_maps_worker_side_input_limit_error_to_400(async_omni_test_client): + engine = async_omni_test_client.app.state.engine_client + engine.get_diffusion_od_config = lambda: None + + async def _fail_generate(*args, **kwargs): + raise RuntimeError("Received 5 input images. At most 4 images are supported by this model.") + yield + + engine.generate = _fail_generate + + response = async_omni_test_client.post( + "/v1/images/edits", + files=[ + ("image", make_test_image_bytes((16, 16))), + ("image", make_test_image_bytes((16, 16))), + ("image", make_test_image_bytes((16, 16))), + ("image", make_test_image_bytes((16, 16))), + ("image", make_test_image_bytes((16, 16))), + ], + data={"prompt": "hello world."}, + ) + + assert response.status_code == 400 + assert response.json()["detail"] == "Received 5 input images. At most 4 images are supported by this model." + + def test_image_edit_parameter_pass(async_omni_test_client): img_bytes_1 = make_test_image_bytes((16, 16)) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 745b719d5b2..2329e019070 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -10,6 +10,7 @@ # Image generation API imports import random +import re import time from argparse import Namespace from collections.abc import AsyncIterator @@ -122,6 +123,10 @@ router = APIRouter() MAX_UINT32_SEED = 2**32 - 1 +_MULTI_IMAGE_SINGLE_INPUT_DETAIL = "Received multiple input images. Only a single image is supported by this model." +_MULTI_IMAGE_LIMIT_DETAIL_RE = re.compile( + r"Received \d+ input images\. At most \d+ images are supported by this model\." +) profiler_router = APIRouter() @@ -1686,6 +1691,12 @@ async def edit_images( logger.error(f"Validation error: {e}") raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, detail=str(e)) except Exception as e: + image_input_validation_detail = _extract_image_input_validation_detail(e) + if image_input_validation_detail is not None: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=image_input_validation_detail, + ) from e logger.exception(f"Image edit failed: {e}") raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=f"Image edit failed: {str(e)}") @@ -1746,10 +1757,43 @@ def _get_max_edit_input_images(raw_request: Request, engine_client: Any) -> int # config is not exposed on the serving surface. return None - if not bool(getattr(od_config, "supports_multimodal_inputs", False)): + supports_multimodal_inputs = _normalize_optional_bool(getattr(od_config, "supports_multimodal_inputs", None)) + if supports_multimodal_inputs is None: + return None + + if not supports_multimodal_inputs: return 1 - return getattr(od_config, "max_multimodal_image_inputs", None) + return _normalize_positive_int(getattr(od_config, "max_multimodal_image_inputs", None)) + + +def _normalize_optional_bool(value: Any) -> bool | None: + if isinstance(value, bool): + return value + return None + + +def _normalize_positive_int(value: Any) -> int | None: + if isinstance(value, bool) or not isinstance(value, int): + return None + if value < 1: + return None + return value + + +def _extract_image_input_validation_detail(exc: Exception) -> str | None: + current: BaseException | None = exc + seen: set[int] = set() + while current is not None and id(current) not in seen: + seen.add(id(current)) + message = str(current) + if message == _MULTI_IMAGE_SINGLE_INPUT_DETAIL: + return message + match = _MULTI_IMAGE_LIMIT_DETAIL_RE.search(message) + if match is not None: + return match.group(0) + current = current.__cause__ or current.__context__ + return None def _get_lora_from_json_str(lora_body): From ea770bb8d1362a363bcbfc5d5e9e289e4fae691c Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Mon, 20 Apr 2026 08:24:07 +0000 Subject: [PATCH 2/2] [Fix] address PR2930 review feedback Signed-off-by: david6666666 <530634352@qq.com> --- .../openai_api/test_image_server.py | 3 +- tests/entrypoints/test_async_omni_abort.py | 30 ++++++++++++++++ vllm_omni/diffusion/data.py | 1 + vllm_omni/diffusion/diffusion_engine.py | 5 +-- .../inline_stage_diffusion_client.py | 3 ++ .../pipeline_qwen_image_edit_plus.py | 3 +- vllm_omni/diffusion/stage_diffusion_client.py | 1 + vllm_omni/diffusion/stage_diffusion_proc.py | 4 +++ .../diffusion/worker/diffusion_worker.py | 5 +-- vllm_omni/engine/orchestrator.py | 1 + vllm_omni/entrypoints/async_omni.py | 3 +- vllm_omni/entrypoints/errors.py | 18 ++++++++++ vllm_omni/entrypoints/openai/api_server.py | 36 ++----------------- vllm_omni/entrypoints/openai/serving_chat.py | 17 ++++++--- vllm_omni/entrypoints/openai/utils.py | 6 ++++ vllm_omni/outputs.py | 1 + 16 files changed, 92 insertions(+), 45 deletions(-) create mode 100644 vllm_omni/entrypoints/errors.py diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index a2b0a0e9422..a2353f2f5d1 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -18,6 +18,7 @@ from pytest_mock import MockerFixture from vllm import SamplingParams +from vllm_omni.entrypoints.errors import InputValidationError from vllm_omni.entrypoints.openai.image_api_utils import ( encode_image_base64, parse_size, @@ -1059,7 +1060,7 @@ def test_image_edit_maps_worker_side_input_limit_error_to_400(async_omni_test_cl engine.get_diffusion_od_config = lambda: None async def _fail_generate(*args, **kwargs): - raise RuntimeError("Received 5 input images. At most 4 images are supported by this model.") + raise InputValidationError("Received 5 input images. At most 4 images are supported by this model.") yield engine.generate = _fail_generate diff --git a/tests/entrypoints/test_async_omni_abort.py b/tests/entrypoints/test_async_omni_abort.py index b34652162d0..0df017b9ed9 100644 --- a/tests/entrypoints/test_async_omni_abort.py +++ b/tests/entrypoints/test_async_omni_abort.py @@ -4,6 +4,7 @@ import pytest from vllm_omni.entrypoints.async_omni import AsyncOmni +from vllm_omni.entrypoints.errors import InputValidationError pytestmark = [pytest.mark.core_model, pytest.mark.cpu] @@ -83,3 +84,32 @@ async def collect_outputs(request_id): ] asyncio.run(run_test()) + + +def test_process_orchestrator_results_restores_input_validation_error(): + async def run_test(): + omni = object.__new__(AsyncOmni) + req_state = SimpleNamespace(queue=asyncio.Queue(), stage_id=0) + omni.request_states = {"req-1": req_state} + + await req_state.queue.put( + { + "request_id": "req-1", + "stage_id": 0, + "error": "Received 5 input images. At most 4 images are supported by this model.", + "error_type": "InputValidationError", + } + ) + + with pytest.raises(InputValidationError, match="At most 4 images are supported by this model"): + async for _ in AsyncOmni._process_orchestrator_results( + omni, + "req-1", + metrics=None, + final_stage_id_for_e2e=0, + req_start_ts={}, + wall_start_ts=0.0, + ): + pass + + asyncio.run(run_test()) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 0a19eb11974..bb90146b38e 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -759,6 +759,7 @@ class DiffusionOutput: trajectory_log_probs: torch.Tensor | dict | None = None trajectory_decoded: list[Image.Image] | None = None error: str | None = None + error_type: str | None = None aborted: bool = False abort_message: str | None = None diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index fe940d623e5..d5fc7df9a84 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -30,6 +30,7 @@ from vllm_omni.diffusion.sched import RequestScheduler, SchedulerInterface, StepScheduler from vllm_omni.diffusion.sched.interface import DiffusionRequestStatus from vllm_omni.diffusion.worker.utils import RunnerOutput +from vllm_omni.entrypoints.errors import get_serialized_error_type, restore_serialized_error from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt from vllm_omni.outputs import OmniRequestOutput @@ -122,7 +123,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: if output.aborted: raise DiffusionRequestAbortedError(output.abort_message or "Diffusion request aborted.") if output.error: - raise RuntimeError(f"{output.error}") + raise restore_serialized_error(output.error, output.error_type) logger.info("Generation completed successfully.") if output.output is None: @@ -364,7 +365,7 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus req_id=sched_req_id, step_index=None, finished=True, - result=DiffusionOutput(error=str(exc)), + result=DiffusionOutput(error=str(exc), error_type=get_serialized_error_type(exc)), ) self._process_aborts_queue() diff --git a/vllm_omni/diffusion/inline_stage_diffusion_client.py b/vllm_omni/diffusion/inline_stage_diffusion_client.py index a33a3e95619..5f026bf37d4 100644 --- a/vllm_omni/diffusion/inline_stage_diffusion_client.py +++ b/vllm_omni/diffusion/inline_stage_diffusion_client.py @@ -20,6 +20,7 @@ from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.engine.stage_init_utils import StageMetadata +from vllm_omni.entrypoints.errors import get_serialized_error_type from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput @@ -123,6 +124,7 @@ async def _dispatch_request( images=[], ) error_output.error = str(e) + error_output.error_type = get_serialized_error_type(e) self._output_queue.put_nowait(error_output) finally: self._tasks.pop(request_id, None) @@ -223,6 +225,7 @@ async def _dispatch_batch( images=[], ) error_output.error = str(e) + error_output.error_type = get_serialized_error_type(e) self._output_queue.put_nowait(error_output) finally: self._tasks.pop(request_id, None) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 2e25d0fe6b2..2a2202d2335 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -48,6 +48,7 @@ normalize_min_aligned_size, ) from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.entrypoints.errors import InputValidationError from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, @@ -101,7 +102,7 @@ def pre_process_func( if not isinstance(raw_image, list): raw_image = [raw_image] if len(raw_image) > MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES: - raise ValueError( + raise InputValidationError( f"Received {len(raw_image)} input images. " f"At most {MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES} images are supported by this model." ) diff --git a/vllm_omni/diffusion/stage_diffusion_client.py b/vllm_omni/diffusion/stage_diffusion_client.py index 480d113d192..181b02ba3ed 100644 --- a/vllm_omni/diffusion/stage_diffusion_client.py +++ b/vllm_omni/diffusion/stage_diffusion_client.py @@ -178,6 +178,7 @@ def _drain_responses(self) -> None: images=[], ) error_output.error = error_msg + error_output.error_type = msg.get("error_type") self._output_queue.put_nowait(error_output) # Fields that are subprocess-local and cannot be serialized across diff --git a/vllm_omni/diffusion/stage_diffusion_proc.py b/vllm_omni/diffusion/stage_diffusion_proc.py index eced444fd32..8525d82efe0 100644 --- a/vllm_omni/diffusion/stage_diffusion_proc.py +++ b/vllm_omni/diffusion/stage_diffusion_proc.py @@ -30,6 +30,7 @@ OmniMsgpackDecoder, OmniMsgpackEncoder, ) +from vllm_omni.entrypoints.errors import get_serialized_error_type from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput @@ -339,6 +340,7 @@ async def _dispatch_request( "type": "error", "request_id": request_id, "error": str(e), + "error_type": get_serialized_error_type(e), } ) ) @@ -394,6 +396,7 @@ async def _dispatch_batch( "type": "error", "request_id": rid, "error": str(e), + "error_type": get_serialized_error_type(e), } ) ) @@ -443,6 +446,7 @@ async def _dispatch_batch( "type": "error", "rpc_id": rpc_id, "error": str(e), + "error_type": get_serialized_error_type(e), } ) ) diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py index 160309e0d8d..9a27b0bd175 100644 --- a/vllm_omni/diffusion/worker/diffusion_worker.py +++ b/vllm_omni/diffusion/worker/diffusion_worker.py @@ -41,6 +41,7 @@ from vllm_omni.diffusion.sched.interface import DiffusionSchedulerOutput from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner from vllm_omni.diffusion.worker.utils import RunnerOutput +from vllm_omni.entrypoints.errors import get_serialized_error_type from vllm_omni.lora.request import LoRARequest from vllm_omni.platforms import current_omni_platform from vllm_omni.profiler import OmniTorchProfilerWrapper, create_omni_profiler @@ -501,7 +502,7 @@ def worker_busy_loop(self) -> None: except Exception as e: logger.error(f"Error processing RPC: {e}", exc_info=True) if self.result_mq is not None: - self.return_result(DiffusionOutput(error=str(e))) + self.return_result(DiffusionOutput(error=str(e), error_type=get_serialized_error_type(e))) elif isinstance(msg, dict) and msg.get("type") == "shutdown": logger.info("Worker %s: Received shutdown message", self.gpu_id) @@ -517,7 +518,7 @@ def worker_busy_loop(self) -> None: f"Error executing forward in event loop: {e}", exc_info=True, ) - output = DiffusionOutput(error=str(e)) + output = DiffusionOutput(error=str(e), error_type=get_serialized_error_type(e)) try: self.return_result(output) diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index a5d7ad032e9..f7a8b7c61f3 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -281,6 +281,7 @@ async def _orchestration_loop(self) -> None: "request_id": parent_id, "stage_id": stage_id, "error": output.error, + "error_type": output.error_type, } ) role_map = self._companion_map.get(parent_id, {}) diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 9606cc80d0d..99ceb5177f0 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -25,6 +25,7 @@ from vllm.v1.engine.exceptions import EngineDeadError from vllm_omni.entrypoints.client_request_state import ClientRequestState +from vllm_omni.entrypoints.errors import restore_serialized_error from vllm_omni.entrypoints.omni_base import OmniBase from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics from vllm_omni.outputs import OmniRequestOutput @@ -441,7 +442,7 @@ async def _process_orchestrator_results( stage_id, result["error"], ) - raise RuntimeError(result) + raise restore_serialized_error(result["error"], result.get("error_type")) # Process the result (constructs OmniRequestOutput) output_to_yield = self._process_single_result( diff --git a/vllm_omni/entrypoints/errors.py b/vllm_omni/entrypoints/errors.py new file mode 100644 index 00000000000..e29a917a96d --- /dev/null +++ b/vllm_omni/entrypoints/errors.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 + + +class InputValidationError(ValueError): + def __init__(self, message: str = "Invalid input.") -> None: + super().__init__(message) + + +def get_serialized_error_type(exc: BaseException) -> str | None: + if isinstance(exc, InputValidationError): + return InputValidationError.__name__ + return None + + +def restore_serialized_error(message: str, error_type: str | None) -> Exception: + if error_type == InputValidationError.__name__: + return InputValidationError(message) + return RuntimeError(message) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 2329e019070..fb1b4ad20d2 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -10,7 +10,6 @@ # Image generation API imports import random -import re import time from argparse import Namespace from collections.abc import AsyncIterator @@ -115,7 +114,7 @@ from vllm_omni.entrypoints.openai.serving_video import OmniOpenAIServingVideo, ReferenceImage from vllm_omni.entrypoints.openai.storage import STORAGE_MANAGER from vllm_omni.entrypoints.openai.stores import VIDEO_STORE, VIDEO_TASKS -from vllm_omni.entrypoints.openai.utils import get_stage_type, parse_lora_request +from vllm_omni.entrypoints.openai.utils import get_stage_type, normalize_optional_bool, parse_lora_request from vllm_omni.entrypoints.openai.video_api_utils import decode_input_reference from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams, OmniTextPrompt @@ -123,10 +122,6 @@ router = APIRouter() MAX_UINT32_SEED = 2**32 - 1 -_MULTI_IMAGE_SINGLE_INPUT_DETAIL = "Received multiple input images. Only a single image is supported by this model." -_MULTI_IMAGE_LIMIT_DETAIL_RE = re.compile( - r"Received \d+ input images\. At most \d+ images are supported by this model\." -) profiler_router = APIRouter() @@ -1691,12 +1686,6 @@ async def edit_images( logger.error(f"Validation error: {e}") raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, detail=str(e)) except Exception as e: - image_input_validation_detail = _extract_image_input_validation_detail(e) - if image_input_validation_detail is not None: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail=image_input_validation_detail, - ) from e logger.exception(f"Image edit failed: {e}") raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=f"Image edit failed: {str(e)}") @@ -1757,7 +1746,7 @@ def _get_max_edit_input_images(raw_request: Request, engine_client: Any) -> int # config is not exposed on the serving surface. return None - supports_multimodal_inputs = _normalize_optional_bool(getattr(od_config, "supports_multimodal_inputs", None)) + supports_multimodal_inputs = normalize_optional_bool(getattr(od_config, "supports_multimodal_inputs", None)) if supports_multimodal_inputs is None: return None @@ -1767,12 +1756,6 @@ def _get_max_edit_input_images(raw_request: Request, engine_client: Any) -> int return _normalize_positive_int(getattr(od_config, "max_multimodal_image_inputs", None)) -def _normalize_optional_bool(value: Any) -> bool | None: - if isinstance(value, bool): - return value - return None - - def _normalize_positive_int(value: Any) -> int | None: if isinstance(value, bool) or not isinstance(value, int): return None @@ -1781,21 +1764,6 @@ def _normalize_positive_int(value: Any) -> int | None: return value -def _extract_image_input_validation_detail(exc: Exception) -> str | None: - current: BaseException | None = exc - seen: set[int] = set() - while current is not None and id(current) not in seen: - seen.add(id(current)) - message = str(current) - if message == _MULTI_IMAGE_SINGLE_INPUT_DETAIL: - return message - match = _MULTI_IMAGE_LIMIT_DETAIL_RE.search(message) - if match is not None: - return match.group(0) - current = current.__cause__ or current.__context__ - return None - - def _get_lora_from_json_str(lora_body): if lora_body is None: return None diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 8cddac6a6c5..c627ee689c3 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -88,6 +88,7 @@ from vllm_omni.entrypoints.openai.utils import ( get_stage_type, get_supported_speakers_from_hf_config, + normalize_optional_bool, parse_lora_request, validate_requested_speaker, ) @@ -2255,8 +2256,12 @@ async def generate_diffusion_images( gen_prompt["multi_modal_data"] = {"image": pil_images[0]} else: od_config = getattr(engine, "od_config", None) - supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", False) - if od_config is None: + supports_multimodal_inputs = ( + True + if od_config is None + else normalize_optional_bool(getattr(od_config, "supports_multimodal_inputs", None)) + ) + if supports_multimodal_inputs is None: supports_multimodal_inputs = True if supports_multimodal_inputs: gen_prompt["multi_modal_data"] = {"image": pil_images} @@ -2459,8 +2464,12 @@ async def _create_diffusion_chat_completion( gen_prompt["multi_modal_data"]["image"] = pil_images[0] else: od_config = getattr(self._diffusion_engine, "od_config", None) - supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", False) - if od_config is None: + supports_multimodal_inputs = ( + True + if od_config is None + else normalize_optional_bool(getattr(od_config, "supports_multimodal_inputs", None)) + ) + if supports_multimodal_inputs is None: # TODO: entry is asyncOmni. We hack the od config here. supports_multimodal_inputs = True if supports_multimodal_inputs: diff --git a/vllm_omni/entrypoints/openai/utils.py b/vllm_omni/entrypoints/openai/utils.py index f411526fdb2..0742fb9d396 100644 --- a/vllm_omni/entrypoints/openai/utils.py +++ b/vllm_omni/entrypoints/openai/utils.py @@ -83,3 +83,9 @@ def validate_requested_speaker(speaker: str | None, supported_speakers: set[str] if supported_speakers and normalized not in supported_speakers: raise ValueError(f"Invalid speaker '{speaker}'. Supported: {', '.join(sorted(supported_speakers))}") return normalized + + +def normalize_optional_bool(value: Any) -> bool | None: + if isinstance(value, bool): + return value + return None diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py index c02c0c1427c..bdffa733418 100644 --- a/vllm_omni/outputs.py +++ b/vllm_omni/outputs.py @@ -102,6 +102,7 @@ class OmniRequestOutput: # error handling error: str | None = None + error_type: str | None = None @classmethod def from_pipeline(