Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,21 @@
from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit_plus import (
get_qwen_image_edit_plus_pre_process_func,
)
from vllm_omni.exceptions import OmniInputValidationError

pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]


def test_qwen_image_edit_plus_rejects_too_many_input_images(tmp_path: Path):
def _make_od_config(tmp_path: Path, *, max_multimodal_image_inputs=None):
vae_dir = tmp_path / "vae"
vae_dir.mkdir()
# Keep the mock config intentionally minimal: this test only needs the
# fields touched during pre-process initialization.
vae_dir.mkdir(exist_ok=True)
(vae_dir / "config.json").write_text(json.dumps({"z_dim": 16}))
return SimpleNamespace(model=str(tmp_path), max_multimodal_image_inputs=max_multimodal_image_inputs)


pre_process = get_qwen_image_edit_plus_pre_process_func(SimpleNamespace(model=str(tmp_path)))
def test_qwen_image_edit_plus_rejects_too_many_input_images(tmp_path: Path):
od_config = _make_od_config(tmp_path)
pre_process = get_qwen_image_edit_plus_pre_process_func(od_config)
image = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
request = SimpleNamespace(
prompts=[
Expand All @@ -34,5 +37,23 @@ def test_qwen_image_edit_plus_rejects_too_many_input_images(tmp_path: Path):
sampling_params=SimpleNamespace(height=None, width=None),
)

with pytest.raises(ValueError, match=r"At most 4 images are supported by this model"):
with pytest.raises(OmniInputValidationError, match=r"At most 4 images are supported"):
pre_process(request)


def test_qwen_image_edit_plus_rejects_images_exceeding_config_limit(tmp_path: Path):
    """Three input images must be rejected when the config limit is set to 2."""
    pre_process = get_qwen_image_edit_plus_pre_process_func(
        _make_od_config(tmp_path, max_multimodal_image_inputs=2)
    )

    # One blank 32x32 RGB image, reused three times to exceed the limit of 2.
    blank = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
    prompt_entry = {
        "prompt": "combine",
        "multi_modal_data": {"image": [blank] * 3},
    }
    request = SimpleNamespace(
        prompts=[prompt_entry],
        sampling_params=SimpleNamespace(height=None, width=None),
    )

    with pytest.raises(OmniInputValidationError, match=r"At most 2 images are supported"):
        pre_process(request)
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_layered import (
QwenImageLayeredPipeline,
)
from vllm_omni.exceptions import OmniInputValidationError

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]

Expand Down Expand Up @@ -258,3 +259,49 @@ def test_qwen_edit_validator_excludes_image_placeholders_from_budget(pipeline_cl
)
def test_forward_max_sequence_length_default_is_1024(pipeline_class: type):
assert inspect.signature(pipeline_class.forward).parameters["max_sequence_length"].default == 1024


def test_forward_caps_max_sequence_length_with_max_multimodal_text_tokens():
    """forward() must raise OmniInputValidationError ("got 17 tokens") when
    ``od_config.max_multimodal_text_tokens`` is 16 and no per-request
    ``max_sequence_length`` is given.
    """
    # Allocate the pipeline without running its own __init__, then initialise
    # only the nn.Module machinery so attribute assignment works.
    pipe = object.__new__(QwenImageEditPlusPipeline)
    nn.Module.__init__(pipe)

    # Attributes the test sets on the bare pipeline, applied in this order.
    stub_attrs = {
        "device": torch.device("cpu"),
        "text_encoder": _RejectingTextEncoder(),
        "tokenizer_max_length": 1024,
        "prompt_template_encode": "{}",
        "prompt_template_encode_start_idx": 64,
        "tokenizer": _FakeTokenizer([17, 0]),
        "processor": _FakeProcessor(17),
        "vae_scale_factor": 8,
        "od_config": SimpleNamespace(max_multimodal_text_tokens=16, max_multimodal_image_inputs=None),
        "check_cfg_parallel_validity": lambda *_args: True,
    }
    for attr_name, attr_value in stub_attrs.items():
        setattr(pipe, attr_name, attr_value)

    # Pre-processed image info: no images, 1024x1024 target size.
    extra_info = {
        "condition_images": [],
        "vae_images": [],
        "condition_image_sizes": [],
        "vae_image_sizes": [],
        "calculated_height": 1024,
        "calculated_width": 1024,
    }
    params = SimpleNamespace(
        height=1024,
        width=1024,
        num_inference_steps=1,
        sigmas=None,
        max_sequence_length=None,
        generator=None,
        true_cfg_scale=1.0,
        guidance_scale_provided=False,
        guidance_scale=1.0,
        num_outputs_per_prompt=0,
    )
    request = SimpleNamespace(
        prompts=[{"prompt": "a long prompt", "additional_information": extra_info}],
        sampling_params=params,
    )

    with pytest.raises(OmniInputValidationError, match=r"got 17 tokens"):
        pipe.forward(request)
10 changes: 9 additions & 1 deletion vllm_omni/diffusion/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,8 @@ class OmniDiffusionConfig:
supports_multimodal_inputs: bool = False
max_multimodal_image_inputs: int | None = None

max_multimodal_text_tokens: int | None = None

log_level: str = "info"

# Omni configuration (injected from stage config)
Expand Down Expand Up @@ -686,9 +688,14 @@ def set_tf_model_config(self, tf_config: "TransformerConfig") -> None:
def update_multimodal_support(self) -> None:
    """Fill the multimodal capability fields from the shared model metadata.

    ``supports_multimodal_inputs`` is always taken from the metadata looked up
    via ``self.model_class_name``.  The ``max_multimodal_image_inputs`` and
    ``max_multimodal_text_tokens`` limits are only filled in while they are
    still ``None``, so a value that was already set on the config is kept.
    """
    # Resolve serving-visible multimodal behavior from shared metadata
    # instead of importing concrete pipeline modules into the config layer.
    # User-supplied values (from stage config) take precedence over
    # model defaults.
    metadata = get_diffusion_model_metadata(self.model_class_name)
    self.supports_multimodal_inputs = metadata.supports_multimodal_inputs
    # Do not assign the image limit unconditionally: that would clobber a
    # user-supplied value and make the ``is None`` guard below dead code.
    if self.max_multimodal_image_inputs is None:
        self.max_multimodal_image_inputs = metadata.max_multimodal_image_inputs
    if self.max_multimodal_text_tokens is None:
        self.max_multimodal_text_tokens = metadata.max_multimodal_text_tokens

def enrich_config(self) -> None:
"""Load model metadata from HuggingFace and populate config fields.
Expand Down Expand Up @@ -796,6 +803,7 @@ class DiffusionOutput:
trajectory_log_probs: torch.Tensor | dict | None = None
trajectory_decoded: list[Image.Image] | None = None
error: str | None = None
error_type: str | None = None
aborted: bool = False
abort_message: str | None = None

Expand Down
2 changes: 1 addition & 1 deletion vllm_omni/diffusion/diffusion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus
req_id=sched_req_id,
step_index=None,
finished=True,
result=DiffusionOutput(error=str(exc)),
result=DiffusionOutput(error=str(exc), error_type=type(exc).__name__),
Comment on lines 367 to +370
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve validation exception type through DiffusionEngine.step

This change stores `error_type` in `DiffusionOutput`, but `DiffusionEngine.step()` still collapses any `output.error` into `RuntimeError` (`raise RuntimeError(output.error)`), so input-validation failures raised in the pipeline's `forward()` (including the new `OmniInputValidationError` prompt-length path) are surfaced upstream as `RuntimeError`. Because `_check_engine_output_error` only converts `error_type == "OmniInputValidationError"` into a 400, these requests still return 500 instead of the intended 400. Please preserve or re-raise the original type from `output.error_type` when `output.error` is present.

Useful? React with 👍 / 👎.

)

self._process_aborts_queue()
Expand Down
2 changes: 2 additions & 0 deletions vllm_omni/diffusion/inline_stage_diffusion_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ async def _dispatch_request(
images=[],
)
error_output.error = str(e)
error_output.error_type = type(e).__name__
self._output_queue.put_nowait(error_output)
finally:
self._tasks.pop(request_id, None)
Expand Down Expand Up @@ -223,6 +224,7 @@ async def _dispatch_batch(
images=[],
)
error_output.error = str(e)
error_output.error_type = type(e).__name__
self._output_queue.put_nowait(error_output)
finally:
self._tasks.pop(request_id, None)
Expand Down
18 changes: 18 additions & 0 deletions vllm_omni/diffusion/model_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,33 @@ class DiffusionModelMetadata:
# config/model plumbing can read it without importing concrete pipelines.
supports_multimodal_inputs: bool = False
max_multimodal_image_inputs: int | None = None
max_multimodal_text_tokens: int | None = None


# Image-input limit used by the "QwenImageEditPlusPipeline" metadata entry.
QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES = 4


# Static per-pipeline metadata table; keys appear to be pipeline class names.
# Any field an entry omits keeps its DiffusionModelMetadata default.
_DIFFUSION_MODEL_METADATA: dict[str, DiffusionModelMetadata] = {
    # These two entries declare only a text-token limit.
    "Flux2Pipeline": DiffusionModelMetadata(
        max_multimodal_text_tokens=512,
    ),
    "QwenImagePipeline": DiffusionModelMetadata(
        max_multimodal_text_tokens=1024,
    ),
    # Multimodal entries with a limit of one input image.
    "QwenImageLayeredPipeline": DiffusionModelMetadata(
        supports_multimodal_inputs=True,
        max_multimodal_image_inputs=1,
        max_multimodal_text_tokens=1024,
    ),
    "QwenImageEditPipeline": DiffusionModelMetadata(
        supports_multimodal_inputs=True,
        max_multimodal_image_inputs=1,
        max_multimodal_text_tokens=1024,
    ),
    # Image limit comes from the shared module-level constant.
    "QwenImageEditPlusPipeline": DiffusionModelMetadata(
        supports_multimodal_inputs=True,
        max_multimodal_image_inputs=QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES,
        max_multimodal_text_tokens=1024,
    ),
}

Expand Down
9 changes: 6 additions & 3 deletions vllm_omni/diffusion/models/flux2/pipeline_flux2.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.exceptions import OmniInputValidationError
from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -67,15 +68,17 @@ def check_image_input(
max_area: int = 1024 * 1024,
) -> PIL.Image.Image:
if not isinstance(image, PIL.Image.Image):
raise ValueError(f"Image must be a PIL.Image.Image, got {type(image)}")
raise OmniInputValidationError(f"Image must be a PIL.Image.Image, got {type(image)}")

width, height = image.size
if width < min_side_length or height < min_side_length:
raise ValueError(f"Image too small: {width}x{height}. Both dimensions must be at least {min_side_length}px")
raise OmniInputValidationError(
f"Image too small: {width}x{height}. Both dimensions must be at least {min_side_length}px"
)

aspect_ratio = max(width / height, height / width)
if aspect_ratio > max_aspect_ratio:
raise ValueError(
raise OmniInputValidationError(
f"Aspect ratio too extreme: {width}x{height} (ratio: {aspect_ratio:.1f}:1). "
f"Maximum allowed ratio is {max_aspect_ratio}:1"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from vllm_omni.diffusion.utils.prompt_utils import (
validate_prompt_sequence_lengths,
)
from vllm_omni.exceptions import OmniInputValidationError
from vllm_omni.diffusion.utils.size_utils import (
normalize_min_aligned_size,
)
Expand Down Expand Up @@ -83,11 +84,11 @@ def pre_process_func(

# Only handles single image
if not raw_image: # None or empty list
raise ValueError("""Received no input image. This model requires one input image to run.""")
raise OmniInputValidationError("Received no input image. This model requires one input image to run.")
elif isinstance(raw_image, list):
if len(raw_image) > 1:
raise ValueError(
"""Received multiple input images. Only a single image is supported by this model."""
raise OmniInputValidationError(
"Received multiple input images. Only a single image is supported by this model."
)
else:
raw_image = raw_image[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from vllm_omni.diffusion.utils.prompt_utils import (
validate_prompt_sequence_lengths,
)
from vllm_omni.exceptions import OmniInputValidationError
from vllm_omni.diffusion.utils.size_utils import (
normalize_min_aligned_size,
)
Expand Down Expand Up @@ -100,10 +101,11 @@

if not isinstance(raw_image, list):
raw_image = [raw_image]
if len(raw_image) > MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES:
raise ValueError(
max_images = od_config.max_multimodal_image_inputs if od_config.max_multimodal_image_inputs is not None else MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES

Check failure on line 104 in vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py:104:121: E501 Line too long (158 > 120)
if len(raw_image) > max_images:
raise OmniInputValidationError(
f"Received {len(raw_image)} input images. "
f"At most {MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES} images are supported by this model."
f"At most {max_images} images are supported by this model."
)
image = [
PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image | np.ndarray | torch.Tensor, im)
Expand Down Expand Up @@ -677,6 +679,8 @@
num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps
sigmas = req.sampling_params.sigmas or sigmas
max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length
if self.od_config.max_multimodal_text_tokens is not None:
max_sequence_length = min(max_sequence_length, self.od_config.max_multimodal_text_tokens)
generator = req.sampling_params.generator or generator
true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale
if req.sampling_params.guidance_scale_provided:
Expand Down Expand Up @@ -726,25 +730,31 @@
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt)

prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
image=condition_images, # Use condition images for prompt encoding
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
)

if do_true_cfg:
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
prompt=negative_prompt,
image=condition_images, # Use same condition images for negative prompt encoding
prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=negative_prompt_embeds_mask,
try:
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
image=condition_images, # Use condition images for prompt encoding
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
prompt_name="negative_prompt",
)
except ValueError as exc:
raise OmniInputValidationError(str(exc)) from exc

if do_true_cfg:
try:
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
prompt=negative_prompt,
image=condition_images, # Use same condition images for negative prompt encoding
prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=negative_prompt_embeds_mask,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
prompt_name="negative_prompt",
)
except ValueError as exc:
raise OmniInputValidationError(str(exc)) from exc

num_channels_latents = self.transformer.in_channels // 4
# random noise latents, and image latents encoded by vae
Expand Down
4 changes: 3 additions & 1 deletion vllm_omni/diffusion/stage_diffusion_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ def _drain_responses(self) -> None:
# Route request errors as error outputs so the Orchestrator
# sees the request complete (instead of hanging forever).
if req_id is not None:
self._output_queue.put_nowait(OmniRequestOutput.from_error(req_id, error_msg))
self._output_queue.put_nowait(OmniRequestOutput.from_error(
req_id, error_msg, error_type=msg.get("error_type"),
))

# Fields that are subprocess-local and cannot be serialized across
# process boundaries. They are recreated in the subprocess with
Expand Down
2 changes: 2 additions & 0 deletions vllm_omni/diffusion/stage_diffusion_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ async def _dispatch_request(
"type": "error",
"request_id": request_id,
"error": str(e),
"error_type": type(e).__name__,
}
)
)
Expand Down Expand Up @@ -396,6 +397,7 @@ async def _dispatch_batch(
"type": "error",
"request_id": rid,
"error": str(e),
"error_type": type(e).__name__,
}
)
)
Expand Down
7 changes: 5 additions & 2 deletions vllm_omni/diffusion/worker/diffusion_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,10 @@ def worker_busy_loop(self) -> None:
except Exception as e:
logger.error(f"Error processing RPC: {e}", exc_info=True)
if self.result_mq is not None:
self.return_result(DiffusionOutput(error=str(e)))
self.return_result(DiffusionOutput(
error=str(e),
error_type=type(e).__name__,
))

elif isinstance(msg, dict) and msg.get("type") == "shutdown":
logger.info("Worker %s: Received shutdown message", self.gpu_id)
Expand All @@ -647,7 +650,7 @@ def worker_busy_loop(self) -> None:
f"Error executing forward in event loop: {e}",
exc_info=True,
)
output = DiffusionOutput(error=str(e))
output = DiffusionOutput(error=str(e), error_type=type(e).__name__)

try:
self.return_result(output)
Expand Down
Loading
Loading