diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index c4155a9fc8b..aa4f0a74f02 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -387,7 +387,12 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: else: image_input = ( - None if isinstance(first_prompt, str) else (first_prompt.get("multi_modal_data") or {}).get("image") + None + if isinstance(first_prompt, str) + else ( + (first_prompt.get("multi_modal_data") or {}).get("image") + or (first_prompt.get("multi_modal_data") or {}).get("img2img") + ) ) if image_input and not isinstance(image_input, list): image_input = [image_input]