diff --git a/tests/comfyui/test_comfyui_integration.py b/tests/comfyui/test_comfyui_integration.py index 5164f3b9acb..80e86d82412 100644 --- a/tests/comfyui/test_comfyui_integration.py +++ b/tests/comfyui/test_comfyui_integration.py @@ -523,7 +523,6 @@ def run_server(): "Qwen/Qwen-Image-Edit", True, id="image-to-image-dalle-endpoint", - marks=pytest.mark.skip(reason="Temporarily disabled due to failure."), ), pytest.param( ServerCase( diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index 607b3eaa813..a2353f2f5d1 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -18,6 +18,7 @@ from pytest_mock import MockerFixture from vllm import SamplingParams +from vllm_omni.entrypoints.errors import InputValidationError from vllm_omni.entrypoints.openai.image_api_utils import ( encode_image_base64, parse_size, @@ -1038,6 +1039,48 @@ def _fail_load(*args, **kwargs): assert engine.captured_prompt is None +def test_image_edit_ignores_non_concrete_diffusion_input_limits(async_omni_test_client): + engine = async_omni_test_client.app.state.engine_client + engine.get_diffusion_od_config = lambda: SimpleNamespace( + supports_multimodal_inputs=SimpleNamespace(), + max_multimodal_image_inputs=SimpleNamespace(), + ) + + response = async_omni_test_client.post( + "/v1/images/edits", + files=[("image", make_test_image_bytes((16, 16)))], + data={"prompt": "hello world."}, + ) + + assert response.status_code == 200 + + +def test_image_edit_maps_worker_side_input_limit_error_to_400(async_omni_test_client): + engine = async_omni_test_client.app.state.engine_client + engine.get_diffusion_od_config = lambda: None + + async def _fail_generate(*args, **kwargs): + raise InputValidationError("Received 5 input images. At most 4 images are supported by this model.") + yield + + engine.generate = _fail_generate + + response = async_omni_test_client.post( + "/v1/images/edits", + files=[ + ("image", make_test_image_bytes((16, 16))), + ("image", make_test_image_bytes((16, 16))), + ("image", make_test_image_bytes((16, 16))), + ("image", make_test_image_bytes((16, 16))), + ("image", make_test_image_bytes((16, 16))), + ], + data={"prompt": "hello world."}, + ) + + assert response.status_code == 400 + assert response.json()["detail"] == "Received 5 input images. At most 4 images are supported by this model." + + def test_image_edit_parameter_pass(async_omni_test_client): img_bytes_1 = make_test_image_bytes((16, 16)) diff --git a/tests/entrypoints/test_async_omni_abort.py b/tests/entrypoints/test_async_omni_abort.py index b34652162d0..0df017b9ed9 100644 --- a/tests/entrypoints/test_async_omni_abort.py +++ b/tests/entrypoints/test_async_omni_abort.py @@ -4,6 +4,7 @@ import pytest from vllm_omni.entrypoints.async_omni import AsyncOmni +from vllm_omni.entrypoints.errors import InputValidationError pytestmark = [pytest.mark.core_model, pytest.mark.cpu] @@ -83,3 +84,32 @@ async def collect_outputs(request_id): ] asyncio.run(run_test()) + + +def test_process_orchestrator_results_restores_input_validation_error(): + async def run_test(): + omni = object.__new__(AsyncOmni) + req_state = SimpleNamespace(queue=asyncio.Queue(), stage_id=0) + omni.request_states = {"req-1": req_state} + + await req_state.queue.put( + { + "request_id": "req-1", + "stage_id": 0, + "error": "Received 5 input images. At most 4 images are supported by this model.", + "error_type": "InputValidationError", + } + ) + + with pytest.raises(InputValidationError, match="At most 4 images are supported by this model"): + async for _ in AsyncOmni._process_orchestrator_results( + omni, + "req-1", + metrics=None, + final_stage_id_for_e2e=0, + req_start_ts={}, + wall_start_ts=0.0, + ): + pass + + asyncio.run(run_test()) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 0a19eb11974..bb90146b38e 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -759,6 +759,7 @@ class DiffusionOutput: trajectory_log_probs: torch.Tensor | dict | None = None trajectory_decoded: list[Image.Image] | None = None error: str | None = None + error_type: str | None = None aborted: bool = False abort_message: str | None = None diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index fe940d623e5..d5fc7df9a84 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -30,6 +30,7 @@ from vllm_omni.diffusion.sched import RequestScheduler, SchedulerInterface, StepScheduler from vllm_omni.diffusion.sched.interface import DiffusionRequestStatus from vllm_omni.diffusion.worker.utils import RunnerOutput +from vllm_omni.entrypoints.errors import get_serialized_error_type, restore_serialized_error from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt from vllm_omni.outputs import OmniRequestOutput @@ -122,7 +123,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: if output.aborted: raise DiffusionRequestAbortedError(output.abort_message or "Diffusion request aborted.") if output.error: - raise RuntimeError(f"{output.error}") + raise restore_serialized_error(output.error, output.error_type) logger.info("Generation completed successfully.") if output.output is None: @@ -364,7 +365,7 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus req_id=sched_req_id, step_index=None, finished=True, - result=DiffusionOutput(error=str(exc)), + result=DiffusionOutput(error=str(exc), error_type=get_serialized_error_type(exc)), ) self._process_aborts_queue() diff --git a/vllm_omni/diffusion/inline_stage_diffusion_client.py b/vllm_omni/diffusion/inline_stage_diffusion_client.py index a33a3e95619..5f026bf37d4 100644 --- a/vllm_omni/diffusion/inline_stage_diffusion_client.py +++ b/vllm_omni/diffusion/inline_stage_diffusion_client.py @@ -20,6 +20,7 @@ from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.engine.stage_init_utils import StageMetadata +from vllm_omni.entrypoints.errors import get_serialized_error_type from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput @@ -123,6 +124,7 @@ async def _dispatch_request( images=[], ) error_output.error = str(e) + error_output.error_type = get_serialized_error_type(e) self._output_queue.put_nowait(error_output) finally: self._tasks.pop(request_id, None) @@ -223,6 +225,7 @@ async def _dispatch_batch( images=[], ) error_output.error = str(e) + error_output.error_type = get_serialized_error_type(e) self._output_queue.put_nowait(error_output) finally: self._tasks.pop(request_id, None) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 2e25d0fe6b2..2a2202d2335 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -48,6 +48,7 @@ normalize_min_aligned_size, ) from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.entrypoints.errors import InputValidationError from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, @@ -101,7 +102,7 @@ def pre_process_func( if not isinstance(raw_image, list): raw_image = [raw_image] if len(raw_image) > MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES: - raise ValueError( + raise InputValidationError( f"Received {len(raw_image)} input images. " f"At most {MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES} images are supported by this model." ) diff --git a/vllm_omni/diffusion/stage_diffusion_client.py b/vllm_omni/diffusion/stage_diffusion_client.py index 480d113d192..181b02ba3ed 100644 --- a/vllm_omni/diffusion/stage_diffusion_client.py +++ b/vllm_omni/diffusion/stage_diffusion_client.py @@ -178,6 +178,7 @@ def _drain_responses(self) -> None: images=[], ) error_output.error = error_msg + error_output.error_type = msg.get("error_type") self._output_queue.put_nowait(error_output) # Fields that are subprocess-local and cannot be serialized across diff --git a/vllm_omni/diffusion/stage_diffusion_proc.py b/vllm_omni/diffusion/stage_diffusion_proc.py index eced444fd32..8525d82efe0 100644 --- a/vllm_omni/diffusion/stage_diffusion_proc.py +++ b/vllm_omni/diffusion/stage_diffusion_proc.py @@ -30,6 +30,7 @@ OmniMsgpackDecoder, OmniMsgpackEncoder, ) +from vllm_omni.entrypoints.errors import get_serialized_error_type from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput @@ -339,6 +340,7 @@ async def _dispatch_request( "type": "error", "request_id": request_id, "error": str(e), + "error_type": get_serialized_error_type(e), } ) ) @@ -394,6 +396,7 @@ async def _dispatch_batch( "type": "error", "request_id": rid, "error": str(e), + "error_type": get_serialized_error_type(e), } ) ) @@ -443,6 +446,7 @@ async def _dispatch_batch( "type": "error", "rpc_id": rpc_id, "error": str(e), + "error_type": get_serialized_error_type(e), } ) ) diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py index 160309e0d8d..9a27b0bd175 100644 --- a/vllm_omni/diffusion/worker/diffusion_worker.py +++ b/vllm_omni/diffusion/worker/diffusion_worker.py @@ -41,6 +41,7 @@ from vllm_omni.diffusion.sched.interface import DiffusionSchedulerOutput from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner from vllm_omni.diffusion.worker.utils import RunnerOutput +from vllm_omni.entrypoints.errors import get_serialized_error_type from vllm_omni.lora.request import LoRARequest from vllm_omni.platforms import current_omni_platform from vllm_omni.profiler import OmniTorchProfilerWrapper, create_omni_profiler @@ -501,7 +502,7 @@ def worker_busy_loop(self) -> None: except Exception as e: logger.error(f"Error processing RPC: {e}", exc_info=True) if self.result_mq is not None: - self.return_result(DiffusionOutput(error=str(e))) + self.return_result(DiffusionOutput(error=str(e), error_type=get_serialized_error_type(e))) elif isinstance(msg, dict) and msg.get("type") == "shutdown": logger.info("Worker %s: Received shutdown message", self.gpu_id) @@ -517,7 +518,7 @@ def worker_busy_loop(self) -> None: f"Error executing forward in event loop: {e}", exc_info=True, ) - output = DiffusionOutput(error=str(e)) + output = DiffusionOutput(error=str(e), error_type=get_serialized_error_type(e)) try: self.return_result(output) diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index a5d7ad032e9..f7a8b7c61f3 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -281,6 +281,7 @@ async def _orchestration_loop(self) -> None: "request_id": parent_id, "stage_id": stage_id, "error": output.error, + "error_type": output.error_type, } ) role_map = self._companion_map.get(parent_id, {}) diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 9606cc80d0d..99ceb5177f0 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -25,6 +25,7 @@ from vllm.v1.engine.exceptions import EngineDeadError from vllm_omni.entrypoints.client_request_state import ClientRequestState +from vllm_omni.entrypoints.errors import restore_serialized_error from vllm_omni.entrypoints.omni_base import OmniBase from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics from vllm_omni.outputs import OmniRequestOutput @@ -441,7 +442,7 @@ async def _process_orchestrator_results( stage_id, result["error"], ) - raise RuntimeError(result) + raise restore_serialized_error(result["error"], result.get("error_type")) # Process the result (constructs OmniRequestOutput) output_to_yield = self._process_single_result( diff --git a/vllm_omni/entrypoints/errors.py b/vllm_omni/entrypoints/errors.py new file mode 100644 index 00000000000..e29a917a96d --- /dev/null +++ b/vllm_omni/entrypoints/errors.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 + + +class InputValidationError(ValueError): + def __init__(self, message: str = "Invalid input.") -> None: + super().__init__(message) + + +def get_serialized_error_type(exc: BaseException) -> str | None: + if isinstance(exc, InputValidationError): + return InputValidationError.__name__ + return None + + +def restore_serialized_error(message: str, error_type: str | None) -> Exception: + if error_type == InputValidationError.__name__: + return InputValidationError(message) + return RuntimeError(message) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 745b719d5b2..fb1b4ad20d2 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -114,7 +114,7 @@ from vllm_omni.entrypoints.openai.serving_video import OmniOpenAIServingVideo, ReferenceImage from vllm_omni.entrypoints.openai.storage import STORAGE_MANAGER from vllm_omni.entrypoints.openai.stores import VIDEO_STORE, VIDEO_TASKS -from vllm_omni.entrypoints.openai.utils import get_stage_type, parse_lora_request +from vllm_omni.entrypoints.openai.utils import get_stage_type, normalize_optional_bool, parse_lora_request from vllm_omni.entrypoints.openai.video_api_utils import decode_input_reference from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams, OmniTextPrompt @@ -1746,10 +1746,22 @@ def _get_max_edit_input_images(raw_request: Request, engine_client: Any) -> int # config is not exposed on the serving surface. return None - if not bool(getattr(od_config, "supports_multimodal_inputs", False)): + supports_multimodal_inputs = normalize_optional_bool(getattr(od_config, "supports_multimodal_inputs", None)) + if supports_multimodal_inputs is None: + return None + + if not supports_multimodal_inputs: return 1 - return getattr(od_config, "max_multimodal_image_inputs", None) + return _normalize_positive_int(getattr(od_config, "max_multimodal_image_inputs", None)) + + +def _normalize_positive_int(value: Any) -> int | None: + if isinstance(value, bool) or not isinstance(value, int): + return None + if value < 1: + return None + return value def _get_lora_from_json_str(lora_body): diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 8cddac6a6c5..c627ee689c3 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -88,6 +88,7 @@ from vllm_omni.entrypoints.openai.utils import ( get_stage_type, get_supported_speakers_from_hf_config, + normalize_optional_bool, parse_lora_request, validate_requested_speaker, ) @@ -2255,8 +2256,12 @@ async def generate_diffusion_images( gen_prompt["multi_modal_data"] = {"image": pil_images[0]} else: od_config = getattr(engine, "od_config", None) - supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", False) - if od_config is None: + supports_multimodal_inputs = ( + True + if od_config is None + else normalize_optional_bool(getattr(od_config, "supports_multimodal_inputs", None)) + ) + if supports_multimodal_inputs is None: supports_multimodal_inputs = True if supports_multimodal_inputs: gen_prompt["multi_modal_data"] = {"image": pil_images} @@ -2459,8 +2464,12 @@ async def _create_diffusion_chat_completion( gen_prompt["multi_modal_data"]["image"] = pil_images[0] else: od_config = getattr(self._diffusion_engine, "od_config", None) - supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", False) - if od_config is None: + supports_multimodal_inputs = ( + True + if od_config is None + else normalize_optional_bool(getattr(od_config, "supports_multimodal_inputs", None)) + ) + if supports_multimodal_inputs is None: # TODO: entry is asyncOmni. We hack the od config here. supports_multimodal_inputs = True if supports_multimodal_inputs: diff --git a/vllm_omni/entrypoints/openai/utils.py b/vllm_omni/entrypoints/openai/utils.py index f411526fdb2..0742fb9d396 100644 --- a/vllm_omni/entrypoints/openai/utils.py +++ b/vllm_omni/entrypoints/openai/utils.py @@ -83,3 +83,9 @@ def validate_requested_speaker(speaker: str | None, supported_speakers: set[str] if supported_speakers and normalized not in supported_speakers: raise ValueError(f"Invalid speaker '{speaker}'. Supported: {', '.join(sorted(supported_speakers))}") return normalized + + +def normalize_optional_bool(value: Any) -> bool | None: + if isinstance(value, bool): + return value + return None diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py index c02c0c1427c..bdffa733418 100644 --- a/vllm_omni/outputs.py +++ b/vllm_omni/outputs.py @@ -102,6 +102,7 @@ class OmniRequestOutput: # error handling error: str | None = None + error_type: str | None = None @classmethod def from_pipeline(