Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion tests/comfyui/test_comfyui_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,6 @@ def run_server():
"Qwen/Qwen-Image-Edit",
True,
id="image-to-image-dalle-endpoint",
marks=pytest.mark.skip(reason="Temporarily disabled due to failure."),
),
pytest.param(
ServerCase(
Expand Down
43 changes: 43 additions & 0 deletions tests/entrypoints/openai_api/test_image_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pytest_mock import MockerFixture
from vllm import SamplingParams

from vllm_omni.entrypoints.errors import InputValidationError
from vllm_omni.entrypoints.openai.image_api_utils import (
encode_image_base64,
parse_size,
Expand Down Expand Up @@ -1038,6 +1039,48 @@ def _fail_load(*args, **kwargs):
assert engine.captured_prompt is None


def test_image_edit_ignores_non_concrete_diffusion_input_limits(async_omni_test_client):
engine = async_omni_test_client.app.state.engine_client
engine.get_diffusion_od_config = lambda: SimpleNamespace(
supports_multimodal_inputs=SimpleNamespace(),
max_multimodal_image_inputs=SimpleNamespace(),
)

response = async_omni_test_client.post(
"/v1/images/edits",
files=[("image", make_test_image_bytes((16, 16)))],
data={"prompt": "hello world."},
)

assert response.status_code == 200


def test_image_edit_maps_worker_side_input_limit_error_to_400(async_omni_test_client):
engine = async_omni_test_client.app.state.engine_client
engine.get_diffusion_od_config = lambda: None

async def _fail_generate(*args, **kwargs):
raise InputValidationError("Received 5 input images. At most 4 images are supported by this model.")
yield

engine.generate = _fail_generate

response = async_omni_test_client.post(
"/v1/images/edits",
files=[
("image", make_test_image_bytes((16, 16))),
("image", make_test_image_bytes((16, 16))),
("image", make_test_image_bytes((16, 16))),
("image", make_test_image_bytes((16, 16))),
("image", make_test_image_bytes((16, 16))),
],
data={"prompt": "hello world."},
)

assert response.status_code == 400
assert response.json()["detail"] == "Received 5 input images. At most 4 images are supported by this model."


def test_image_edit_parameter_pass(async_omni_test_client):
img_bytes_1 = make_test_image_bytes((16, 16))

Expand Down
30 changes: 30 additions & 0 deletions tests/entrypoints/test_async_omni_abort.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.errors import InputValidationError

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]

Expand Down Expand Up @@ -83,3 +84,32 @@ async def collect_outputs(request_id):
]

asyncio.run(run_test())


def test_process_orchestrator_results_restores_input_validation_error():
async def run_test():
omni = object.__new__(AsyncOmni)
req_state = SimpleNamespace(queue=asyncio.Queue(), stage_id=0)
omni.request_states = {"req-1": req_state}

await req_state.queue.put(
{
"request_id": "req-1",
"stage_id": 0,
"error": "Received 5 input images. At most 4 images are supported by this model.",
"error_type": "InputValidationError",
}
)

with pytest.raises(InputValidationError, match="At most 4 images are supported by this model"):
async for _ in AsyncOmni._process_orchestrator_results(
omni,
"req-1",
metrics=None,
final_stage_id_for_e2e=0,
req_start_ts={},
wall_start_ts=0.0,
):
pass

asyncio.run(run_test())
1 change: 1 addition & 0 deletions vllm_omni/diffusion/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,7 @@ class DiffusionOutput:
trajectory_log_probs: torch.Tensor | dict | None = None
trajectory_decoded: list[Image.Image] | None = None
error: str | None = None
error_type: str | None = None
aborted: bool = False
abort_message: str | None = None

Expand Down
5 changes: 3 additions & 2 deletions vllm_omni/diffusion/diffusion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from vllm_omni.diffusion.sched import RequestScheduler, SchedulerInterface, StepScheduler
from vllm_omni.diffusion.sched.interface import DiffusionRequestStatus
from vllm_omni.diffusion.worker.utils import RunnerOutput
from vllm_omni.entrypoints.errors import get_serialized_error_type, restore_serialized_error
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt
from vllm_omni.outputs import OmniRequestOutput

Expand Down Expand Up @@ -122,7 +123,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
if output.aborted:
raise DiffusionRequestAbortedError(output.abort_message or "Diffusion request aborted.")
if output.error:
raise RuntimeError(f"{output.error}")
raise restore_serialized_error(output.error, output.error_type)
logger.info("Generation completed successfully.")

if output.output is None:
Expand Down Expand Up @@ -364,7 +365,7 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus
req_id=sched_req_id,
step_index=None,
finished=True,
result=DiffusionOutput(error=str(exc)),
result=DiffusionOutput(error=str(exc), error_type=get_serialized_error_type(exc)),
)

self._process_aborts_queue()
Expand Down
3 changes: 3 additions & 0 deletions vllm_omni/diffusion/inline_stage_diffusion_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from vllm_omni.diffusion.diffusion_engine import DiffusionEngine
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.engine.stage_init_utils import StageMetadata
from vllm_omni.entrypoints.errors import get_serialized_error_type
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput

Expand Down Expand Up @@ -123,6 +124,7 @@ async def _dispatch_request(
images=[],
)
error_output.error = str(e)
error_output.error_type = get_serialized_error_type(e)
self._output_queue.put_nowait(error_output)
finally:
self._tasks.pop(request_id, None)
Expand Down Expand Up @@ -223,6 +225,7 @@ async def _dispatch_batch(
images=[],
)
error_output.error = str(e)
error_output.error_type = get_serialized_error_type(e)
self._output_queue.put_nowait(error_output)
finally:
self._tasks.pop(request_id, None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
normalize_min_aligned_size,
)
from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.entrypoints.errors import InputValidationError
from vllm_omni.inputs.data import OmniTextPrompt
from vllm_omni.model_executor.model_loader.weight_utils import (
download_weights_from_hf_specific,
Expand Down Expand Up @@ -101,7 +102,7 @@ def pre_process_func(
if not isinstance(raw_image, list):
raw_image = [raw_image]
if len(raw_image) > MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES:
raise ValueError(
raise InputValidationError(
f"Received {len(raw_image)} input images. "
f"At most {MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES} images are supported by this model."
)
Expand Down
1 change: 1 addition & 0 deletions vllm_omni/diffusion/stage_diffusion_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def _drain_responses(self) -> None:
images=[],
)
error_output.error = error_msg
error_output.error_type = msg.get("error_type")
self._output_queue.put_nowait(error_output)

# Fields that are subprocess-local and cannot be serialized across
Expand Down
4 changes: 4 additions & 0 deletions vllm_omni/diffusion/stage_diffusion_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
OmniMsgpackDecoder,
OmniMsgpackEncoder,
)
from vllm_omni.entrypoints.errors import get_serialized_error_type
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput

Expand Down Expand Up @@ -339,6 +340,7 @@ async def _dispatch_request(
"type": "error",
"request_id": request_id,
"error": str(e),
"error_type": get_serialized_error_type(e),
}
)
)
Expand Down Expand Up @@ -394,6 +396,7 @@ async def _dispatch_batch(
"type": "error",
"request_id": rid,
"error": str(e),
"error_type": get_serialized_error_type(e),
}
)
)
Expand Down Expand Up @@ -443,6 +446,7 @@ async def _dispatch_batch(
"type": "error",
"rpc_id": rpc_id,
"error": str(e),
"error_type": get_serialized_error_type(e),
}
)
)
Expand Down
5 changes: 3 additions & 2 deletions vllm_omni/diffusion/worker/diffusion_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from vllm_omni.diffusion.sched.interface import DiffusionSchedulerOutput
from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner
from vllm_omni.diffusion.worker.utils import RunnerOutput
from vllm_omni.entrypoints.errors import get_serialized_error_type
from vllm_omni.lora.request import LoRARequest
from vllm_omni.platforms import current_omni_platform
from vllm_omni.profiler import OmniTorchProfilerWrapper, create_omni_profiler
Expand Down Expand Up @@ -501,7 +502,7 @@ def worker_busy_loop(self) -> None:
except Exception as e:
logger.error(f"Error processing RPC: {e}", exc_info=True)
if self.result_mq is not None:
self.return_result(DiffusionOutput(error=str(e)))
self.return_result(DiffusionOutput(error=str(e), error_type=get_serialized_error_type(e)))

elif isinstance(msg, dict) and msg.get("type") == "shutdown":
logger.info("Worker %s: Received shutdown message", self.gpu_id)
Expand All @@ -517,7 +518,7 @@ def worker_busy_loop(self) -> None:
f"Error executing forward in event loop: {e}",
exc_info=True,
)
output = DiffusionOutput(error=str(e))
output = DiffusionOutput(error=str(e), error_type=get_serialized_error_type(e))

try:
self.return_result(output)
Expand Down
1 change: 1 addition & 0 deletions vllm_omni/engine/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ async def _orchestration_loop(self) -> None:
"request_id": parent_id,
"stage_id": stage_id,
"error": output.error,
"error_type": output.error_type,
}
)
role_map = self._companion_map.get(parent_id, {})
Expand Down
3 changes: 2 additions & 1 deletion vllm_omni/entrypoints/async_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from vllm.v1.engine.exceptions import EngineDeadError

from vllm_omni.entrypoints.client_request_state import ClientRequestState
from vllm_omni.entrypoints.errors import restore_serialized_error
from vllm_omni.entrypoints.omni_base import OmniBase
from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics
from vllm_omni.outputs import OmniRequestOutput
Expand Down Expand Up @@ -441,7 +442,7 @@ async def _process_orchestrator_results(
stage_id,
result["error"],
)
raise RuntimeError(result)
raise restore_serialized_error(result["error"], result.get("error_type"))

# Process the result (constructs OmniRequestOutput)
output_to_yield = self._process_single_result(
Expand Down
18 changes: 18 additions & 0 deletions vllm_omni/entrypoints/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0


class InputValidationError(ValueError):
def __init__(self, message: str = "Invalid input.") -> None:
super().__init__(message)


def get_serialized_error_type(exc: BaseException) -> str | None:
if isinstance(exc, InputValidationError):
return InputValidationError.__name__
return None


def restore_serialized_error(message: str, error_type: str | None) -> Exception:
if error_type == InputValidationError.__name__:
return InputValidationError(message)
return RuntimeError(message)
18 changes: 15 additions & 3 deletions vllm_omni/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@
from vllm_omni.entrypoints.openai.serving_video import OmniOpenAIServingVideo, ReferenceImage
from vllm_omni.entrypoints.openai.storage import STORAGE_MANAGER
from vllm_omni.entrypoints.openai.stores import VIDEO_STORE, VIDEO_TASKS
from vllm_omni.entrypoints.openai.utils import get_stage_type, parse_lora_request
from vllm_omni.entrypoints.openai.utils import get_stage_type, normalize_optional_bool, parse_lora_request
from vllm_omni.entrypoints.openai.video_api_utils import decode_input_reference
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams, OmniTextPrompt

Expand Down Expand Up @@ -1746,10 +1746,22 @@ def _get_max_edit_input_images(raw_request: Request, engine_client: Any) -> int
# config is not exposed on the serving surface.
return None

if not bool(getattr(od_config, "supports_multimodal_inputs", False)):
supports_multimodal_inputs = normalize_optional_bool(getattr(od_config, "supports_multimodal_inputs", None))
if supports_multimodal_inputs is None:
return None

if not supports_multimodal_inputs:
return 1

return getattr(od_config, "max_multimodal_image_inputs", None)
return _normalize_positive_int(getattr(od_config, "max_multimodal_image_inputs", None))


def _normalize_positive_int(value: Any) -> int | None:
if isinstance(value, bool) or not isinstance(value, int):
return None
if value < 1:
return None
return value


def _get_lora_from_json_str(lora_body):
Expand Down
17 changes: 13 additions & 4 deletions vllm_omni/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
from vllm_omni.entrypoints.openai.utils import (
get_stage_type,
get_supported_speakers_from_hf_config,
normalize_optional_bool,
parse_lora_request,
validate_requested_speaker,
)
Expand Down Expand Up @@ -2255,8 +2256,12 @@ async def generate_diffusion_images(
gen_prompt["multi_modal_data"] = {"image": pil_images[0]}
else:
od_config = getattr(engine, "od_config", None)
supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", False)
if od_config is None:
supports_multimodal_inputs = (
True
if od_config is None
else normalize_optional_bool(getattr(od_config, "supports_multimodal_inputs", None))
)
if supports_multimodal_inputs is None:
supports_multimodal_inputs = True
if supports_multimodal_inputs:
gen_prompt["multi_modal_data"] = {"image": pil_images}
Expand Down Expand Up @@ -2459,8 +2464,12 @@ async def _create_diffusion_chat_completion(
gen_prompt["multi_modal_data"]["image"] = pil_images[0]
else:
od_config = getattr(self._diffusion_engine, "od_config", None)
supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", False)
if od_config is None:
supports_multimodal_inputs = (
True
if od_config is None
else normalize_optional_bool(getattr(od_config, "supports_multimodal_inputs", None))
)
if supports_multimodal_inputs is None:
# TODO: entry is asyncOmni. We hack the od config here.
supports_multimodal_inputs = True
if supports_multimodal_inputs:
Expand Down
6 changes: 6 additions & 0 deletions vllm_omni/entrypoints/openai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,9 @@ def validate_requested_speaker(speaker: str | None, supported_speakers: set[str]
if supported_speakers and normalized not in supported_speakers:
raise ValueError(f"Invalid speaker '{speaker}'. Supported: {', '.join(sorted(supported_speakers))}")
return normalized


def normalize_optional_bool(value: Any) -> bool | None:
if isinstance(value, bool):
return value
return None
1 change: 1 addition & 0 deletions vllm_omni/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class OmniRequestOutput:

# error handling
error: str | None = None
error_type: str | None = None

@classmethod
def from_pipeline(
Expand Down
Loading