diff --git a/vllm_omni/diffusion/models/flux/pipeline_flux.py b/vllm_omni/diffusion/models/flux/pipeline_flux.py index 01e0bfba9d2..a91601b7ec8 100644 --- a/vllm_omni/diffusion/models/flux/pipeline_flux.py +++ b/vllm_omni/diffusion/models/flux/pipeline_flux.py @@ -557,10 +557,10 @@ def diffuse( def forward( self, req: OmniDiffusionRequest, - prompt: str | list[str] = None, - prompt_2: str | list[str] = None, - negative_prompt: str | list[str] = None, - negative_prompt_2: str | list[str] = None, + prompt: str | list[str] | None = None, + prompt_2: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, true_cfg_scale: float = 1.0, height: int | None = None, width: int | None = None, @@ -581,16 +581,27 @@ def forward( max_sequence_length: int = 512, ): """Forward pass for flux.""" - prompt = req.prompt if req.prompt is not None else prompt - negative_prompt = req.negative_prompt if req.negative_prompt is not None else negative_prompt - height = req.height or self.default_sample_size * self.vae_scale_factor - width = req.width or self.default_sample_size * self.vae_scale_factor - sigmas = req.sigmas or sigmas - num_inference_steps = req.num_inference_steps or num_inference_steps - generator = req.generator or generator - req_num_outputs = getattr(req, "num_outputs_per_prompt", None) - if req_num_outputs and req_num_outputs > 0: - num_images_per_prompt = req_num_outputs + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): + negative_prompt = None + elif req.prompts: + negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + + height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor + width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas or sigmas + guidance_scale = ( + req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale + ) + generator = req.sampling_params.generator or generator + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) # 1. Check inputs. Raise error if not correct self.check_inputs(