diff --git a/docs/contributing/model/adding_diffusion_model.md b/docs/contributing/model/adding_diffusion_model.md index 818bd15b736..ce6c7afa51d 100644 --- a/docs/contributing/model/adding_diffusion_model.md +++ b/docs/contributing/model/adding_diffusion_model.md @@ -337,6 +337,27 @@ See some parameters in `OmniDiffusionSamplingParams` as follows: **Extract parameters from request:** +The `OmniDiffusionRequest` object primarily contains two parts. + +1. **`prompt`**: a list of pure-string or multimodal prompt. It matches the [data structure of vLLM](https://docs.vllm.ai/en/stable/features/multimodal_inputs/#image-inputs). Each prompt in the list can be a string or a TypedDict. The dict version allows image input at `["multi_modal_data"]["images"]` and negative prompt at `["negative_prompt"]`. + - If your model requires a preprocess function, then the intermediate preprocessed values can be stored at the `["additional_information"]` field of a TypedDict prompt. + - If your model does not support batched input request, you can check the length of `req.prompts` and complain about the input to the user. In this case, the user is encouraged to request the prompts one-by-one. + - For example, an image editing model may expect the `prompt` to be something like this: + ```python + [ + { + "prompt": "turn this cat to a dog", + "multi_modal_data": {"image": input_image} + }, + ] + ``` + +2. **`sampling_params`**: a collection of common sampling parameters. Check the definition of [`OmniDiffusionSamplingParams`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/inputs/data/#vllm_omni.inputs.data.OmniDiffusionSamplingParams) dataclass for their default values. + - If your model requires a less-common sampling parameter, you can read it from the `["extra_args"]` field of the dataclass. To ensure user experience, you may want to document the list of extra args that your pipeline honors. + - If you believe a sampling parameter is common enough to be included in the `OmniDiffusionSamplingParams` dataclass, feel free to open an issue or clarify it in your PR that adds your model. + +Below is an example way to extract the prompt strings and sampling parameters from the `OmniDiffusionRequest`. + ```python from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.data import DiffusionOutput @@ -366,14 +387,6 @@ def forward( # ... rest of generation logic ``` -For an image editing model, an example `OmniDiffusionRequest` is like: -```python -{ - "prompt": "turn this cat to a dog", - "multi_modal_data": {"image": input_image} -}, -``` - **Wrap output:** ```diff diff --git a/docs/design/feature/cfg_parallel.md b/docs/design/feature/cfg_parallel.md index e31e64eddd4..779d25c406f 100644 --- a/docs/design/feature/cfg_parallel.md +++ b/docs/design/feature/cfg_parallel.md @@ -113,14 +113,7 @@ Call `self.diffuse` in your pipeline's forward function: ```python import torch.nn as nn class YourModelPipeline(nn.Module, CFGParallelMixin): - def forward( - self, - prompt: str, - negative_prompt: str | None = None, - guidance_scale: float = 3.5, - num_inference_steps: int = 50, - **kwargs, - ): + def forward(self, req: OmniDiffusionRequest): # Encode prompts, Initialize latents, Get timesteps ... # Run diffusion loop (calls the mixin's diffuse method) diff --git a/docs/features/custom_pipeline.md b/docs/features/custom_pipeline.md index aee5cfa6bee..44b5b1f0262 100644 --- a/docs/features/custom_pipeline.md +++ b/docs/features/custom_pipeline.md @@ -62,12 +62,12 @@ class CustomPipeline(QwenImageEditPipeline): def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): super().__init__(od_config=od_config, prefix=prefix) - def forward(self, req, prompt=None, negative_prompt=None, **kwargs): + def forward(self, req): # Call parent's forward to get normal output - output = super().forward(req=req, prompt=prompt, negative_prompt=negative_prompt, **kwargs) + output = super().forward(req=req) # Add custom trajectory data - actual_num_steps = req.sampling_params.num_inference_steps or kwargs.get('num_inference_steps', 50) + actual_num_steps = req.sampling_params.num_inference_steps or 50 output.trajectory_timesteps = torch.linspace(1000, 0, actual_num_steps, dtype=torch.float32) output.trajectory_latents = torch.randn(actual_num_steps, 1, 16, 64, 64, dtype=torch.float32) diff --git a/examples/offline_inference/custom_pipeline/image_to_image/custom_pipeline.py b/examples/offline_inference/custom_pipeline/image_to_image/custom_pipeline.py index b1c74e2b345..1e5be8631ac 100644 --- a/examples/offline_inference/custom_pipeline/image_to_image/custom_pipeline.py +++ b/examples/offline_inference/custom_pipeline/image_to_image/custom_pipeline.py @@ -2,9 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging -from typing import Any -import PIL.Image import torch from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig @@ -18,58 +16,13 @@ class CustomPipeline(QwenImageEditPipeline): def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): super().__init__(od_config=od_config, prefix=prefix) - def forward( - self, - req: OmniDiffusionRequest, - prompt: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - image: PIL.Image.Image | torch.Tensor | None = None, - true_cfg_scale: float = 4.0, - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - guidance_scale: float = 1.0, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.Tensor | None = None, - prompt_embeds: torch.Tensor | None = None, - prompt_embeds_mask: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds_mask: torch.Tensor | None = None, - output_type: str | None = "pil", - attention_kwargs: dict[str, Any] | None = None, - callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, - ) -> DiffusionOutput: + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: """Forward pass for image editing with dummy trajectory data.""" # Call parent's forward to get the normal output - output = super().forward( - req=req, - prompt=prompt, - negative_prompt=negative_prompt, - image=image, - true_cfg_scale=true_cfg_scale, - height=height, - width=width, - num_inference_steps=num_inference_steps, - sigmas=sigmas, - guidance_scale=guidance_scale, - num_images_per_prompt=num_images_per_prompt, - generator=generator, - latents=latents, - prompt_embeds=prompt_embeds, - prompt_embeds_mask=prompt_embeds_mask, - negative_prompt_embeds=negative_prompt_embeds, - negative_prompt_embeds_mask=negative_prompt_embeds_mask, - output_type=output_type, - attention_kwargs=attention_kwargs, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - max_sequence_length=max_sequence_length, - ) + output = super().forward(req=req) # Get actual num_inference_steps used - actual_num_steps = req.sampling_params.num_inference_steps or num_inference_steps + actual_num_steps = req.sampling_params.num_inference_steps or 50 # Create dummy trajectory data dummy_trajectory_latents = torch.randn(actual_num_steps, 1, 16, 64, 64, dtype=torch.float32) diff --git a/vllm_omni/diffusion/models/flux/pipeline_flux.py b/vllm_omni/diffusion/models/flux/pipeline_flux.py index 3955fee120e..3210c126966 100644 --- a/vllm_omni/diffusion/models/flux/pipeline_flux.py +++ b/vllm_omni/diffusion/models/flux/pipeline_flux.py @@ -583,54 +583,76 @@ def check_cfg_parallel_validity(self, true_cfg_scale: float, has_neg_prompt: boo return False return True - def forward( - self, - req: OmniDiffusionRequest, - prompt: str | list[str] | None = None, - prompt_2: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - negative_prompt_2: str | list[str] | None = None, - true_cfg_scale: float = 1.0, - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 28, - sigmas: list[float] | None = None, - guidance_scale: float = 3.5, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.FloatTensor | None = None, - prompt_embeds: torch.FloatTensor | None = None, - pooled_prompt_embeds: torch.FloatTensor | None = None, - negative_prompt_embeds: torch.FloatTensor | None = None, - negative_pooled_prompt_embeds: torch.FloatTensor | None = None, - output_type: str | None = "pil", - return_dict: bool = True, - joint_attention_kwargs: dict[str, Any] | None = None, - callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, - ): + def forward(self, req: OmniDiffusionRequest): """Forward pass for flux.""" # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") # TODO: May be some data formatting operations on the API side. Hack for now. - prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] + + # For negative prompt, make it None if ALL are None---making it falsy and skipping CFG + # If only some of them are not None, only set those to empty strings---because we cannot skip CFG anyway. if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): negative_prompt = None elif req.prompts: negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + try: + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + except TypeError: + raise ValueError( + "If you provide `prompt_embeds` for at least one prompt, you have to provide `prompt_embeds` for" + " all prompts so the pipeline can stack them together." + ) + else: + prompt_embeds = None + + req_negative_prompt_embeds = [ + p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts + ] + if any(p is not None for p in req_negative_prompt_embeds): + try: + negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError + except TypeError: + raise ValueError( + "If you provide `negative_prompt_embeds` for at least one prompt, " + "you have to provide `negative_prompt_embeds` for all prompts " + "so the pipeline can stack them together." + ) + else: + negative_prompt_embeds = None + height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - sigmas = req.sampling_params.sigmas or sigmas - guidance_scale = ( - req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale - ) - generator = req.sampling_params.generator or generator - true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + num_inference_steps = req.sampling_params.num_inference_steps or 28 + sigmas = req.sampling_params.sigmas + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 3.5 + generator = req.sampling_params.generator + true_cfg_scale = req.sampling_params.true_cfg_scale or 1.0 + max_sequence_length = req.sampling_params.max_sequence_length or 512 num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt > 0 - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 + ) + latents = req.sampling_params.latents + + prompt_2: str | list[str] | None = req.sampling_params.extra_args.get("prompt_2", None) + negative_prompt_2: str | list[str] | None = req.sampling_params.extra_args.get("negative_prompt_2", None) + pooled_prompt_embeds: torch.FloatTensor | None = req.sampling_params.extra_args.get( + "pooled_prompt_embeds", None + ) + negative_pooled_prompt_embeds: torch.FloatTensor | None = req.sampling_params.extra_args.get( + "negative_pooled_prompt_embeds", None + ) + output_type: str = req.sampling_params.extra_args.get("output_type", "pil") + joint_attention_kwargs: dict[str, Any] | None = req.sampling_params.extra_args.get( + "joint_attention_kwargs", None + ) + callback_on_step_end_tensor_inputs: list[str] = req.sampling_params.extra_args.get( + "callback_on_step_end_tensor_inputs", ["latents"] ) # 1. Check inputs. Raise error if not correct diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index 551be4f069f..7dc0f13e754 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -647,96 +647,33 @@ def current_timestep(self): def interrupt(self): return self._interrupt - def forward( - self, - req: OmniDiffusionRequest, - image: PIL.Image.Image | list[PIL.Image.Image] | None = None, - prompt: str | list[str] | None = None, - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - guidance_scale: float | None = 4.0, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.Tensor | None = None, - prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, - output_type: str | None = "pil", - return_dict: bool = True, - attention_kwargs: dict[str, Any] | None = None, - callback_on_step_end: Callable[[int, int, dict], None] | None = None, - callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, - text_encoder_out_layers: tuple[int, ...] = (9, 18, 27), - ) -> DiffusionOutput: + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: r""" Function invoked when calling the pipeline for generation. Args: - image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or list of these): - `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both - numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list - or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a - list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image - latents as `image`, but if passing latents directly it is not encoded again. - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - guidance_scale (`float`, *optional*, defaults to 4.0): - Guidance scale as defined in [Classifier-Free Diffusion - Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. - of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting - `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. For step-wise distilled models, - `guidance_scale` is ignored. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. This is set to 1024 by default for the best results. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. This is set to 1024 by default for the best results. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - sigmas (`List[float]`, *optional*): - Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in - their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed - will be used. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will be generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Note that "" is used as the negative prompt in this pipeline. - If not provided, will be generated from "". - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. - attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. - text_encoder_out_layers (`Tuple[int]`): - Layer indices to use in the `text_encoder` to derive the final prompt embeddings. + req (`OmniDiffusionRequest`): + The request object containing the prompts and sampling parameters. + The `req.sampling_params.extra_args` can include the following keys: + - output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + - attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + - callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is + called with the following arguments: + `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. + `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + - callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in + the `._callback_tensor_inputs` attribute of your pipeline class. + - text_encoder_out_layers (`Tuple[int]`): + Layer indices to use in the `text_encoder` to derive the final prompt embeddings. Examples: @@ -753,46 +690,46 @@ def forward( first_prompt = req.prompts[0] prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") + prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds") + negative_prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds") + if ( raw_image := None if isinstance(first_prompt, str) else first_prompt.get("multi_modal_data", {}).get("image") ) is None: - pass # use image from param list + image = None elif isinstance(raw_image, list): image = [PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im) for im in raw_image] else: image = PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image) - height = req.sampling_params.height or height - width = req.sampling_params.width or width - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - sigmas = req.sampling_params.sigmas or sigmas - guidance_scale = ( - req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale - ) - generator = req.sampling_params.generator or generator + height = req.sampling_params.height + width = req.sampling_params.width + num_inference_steps = req.sampling_params.num_inference_steps or 50 + sigmas = req.sampling_params.sigmas + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 4.0 + generator = req.sampling_params.generator num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt > 0 - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 + ) + max_sequence_length = req.sampling_params.max_sequence_length or 512 + latents = req.sampling_params.latents + + output_type: str = req.sampling_params.extra_args.get("output_type", "pil") + attention_kwargs: dict[str, Any] | None = req.sampling_params.extra_args.get("attention_kwargs", None) + callback_on_step_end: Callable[[int, int, dict], None] | None = req.sampling_params.extra_args.get( + "callback_on_step_end", None + ) + callback_on_step_end_tensor_inputs: list[str] = req.sampling_params.extra_args.get( + "callback_on_step_end_tensor_inputs", ["latents"] + ) + text_encoder_out_layers: tuple[int, ...] = req.sampling_params.extra_args.get( + "text_encoder_out_layers", (9, 18, 27) ) - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length - text_encoder_out_layers = req.sampling_params.extra_args.get("text_encoder_out_layers", text_encoder_out_layers) - - req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] - if any(p is not None for p in req_prompt_embeds): - # If at list one prompt is provided as an embedding, - # Then assume that the user wants to provide embeddings for all prompts, and enter this if block - # If the user in fact provides mixed input format, req_prompt_embeds will have some None's - # And `torch.stack` automatically raises an exception for us - prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError - - req_negative_prompt_embeds = [ - p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts - ] - if any(p is not None for p in req_negative_prompt_embeds): - negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError # 1. Check inputs. Raise error if not correct self.check_inputs( diff --git a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py index 8fca5069414..d0d4b44525b 100644 --- a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py +++ b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py @@ -719,7 +719,10 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: height = req.sampling_params.height or img_height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or img_width or self.default_sample_size * self.vae_scale_factor num_inference_steps = req.sampling_params.num_inference_steps or 50 - guidance_scale = req.sampling_params.guidance_scale or 1.5 + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 1.5 self.check_inputs(prompt=prompt, height=height, width=width, prompt_embeds=prompt_embeds) diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py b/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py index e4b717a6979..8999193fd93 100644 --- a/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py +++ b/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py @@ -961,29 +961,22 @@ def forward_call( return output - def forward( - self, - req: OmniDiffusionRequest, - prompt: str | list[str] = "", - image_size="auto", - height: int = 1024, - width: int = 1024, - num_inference_steps: int = 50, - guidance_scale: float = 5.0, - system_prompt: str | None = None, - generator: torch.Generator | list[torch.Generator] | None = None, - **kwargs, - ) -> DiffusionOutput: - prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt - generator = req.sampling_params.generator or generator - height = req.sampling_params.height or height - width = req.sampling_params.width or width - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] + generator = req.sampling_params.generator + height = req.sampling_params.height or 1024 + width = req.sampling_params.width or 1024 + num_inference_steps = req.sampling_params.num_inference_steps or 50 if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 5.0 if guidance_scale <= 1.0: logger.warning("HunyuanImage3.0 does not support guidance_scale <= 1.0, will set it to 1.0 + epsilon.") guidance_scale = 1.0 + np.finfo(float).eps + + system_prompt: str | None = req.sampling_params.extra_args.get("system_prompt") + image_size = (height, width) model_inputs = self.prepare_model_inputs( prompt=prompt, @@ -995,5 +988,5 @@ def forward( num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, ) - outputs = self._generate(**model_inputs, **kwargs) + outputs = self._generate(**model_inputs) return DiffusionOutput(output=outputs[0]) diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py index 09f409f3139..c2f451695a2 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py @@ -470,66 +470,62 @@ def cfg_normalize_function(self, noise_pred, comb_pred, cfg_renorm_min=0.0): noise_pred = comb_pred * scale return noise_pred - def forward( - self, - req: OmniDiffusionRequest, - prompt: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - guidance_scale: float = 4.5, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.FloatTensor | None = None, - prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, - output_type: str | None = "pil", - return_dict: bool = True, - joint_attention_kwargs: dict[str, Any] | None = None, - enable_cfg_renorm: bool | None = True, - cfg_renorm_min: float | None = 0.0, - enable_prompt_rewrite: bool | None = True, - ) -> DiffusionOutput: + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") # TODO: May be some data formatting operations on the API side. Hack for now. - prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): negative_prompt = None elif req.prompts: negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] - height = req.sampling_params.height or height or self.default_sample_size * self.vae_scale_factor - width = req.sampling_params.width or width or self.default_sample_size * self.vae_scale_factor - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - sigmas = req.sampling_params.sigmas or sigmas - generator = req.sampling_params.generator or generator - guidance_scale = ( - req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale - ) - num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt is not None - else num_images_per_prompt - ) - enable_prompt_rewrite = req.sampling_params.extra_args.get("enable_prompt_rewrite", enable_prompt_rewrite) - enable_cfg_renorm = req.sampling_params.extra_args.get("enable_cfg_renorm", enable_cfg_renorm) - cfg_renorm_min = req.sampling_params.extra_args.get("cfg_renorm_min", cfg_renorm_min) - req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] if any(p is not None for p in req_prompt_embeds): - # If at list one prompt is provided as an embedding, - # Then assume that the user wants to provide embeddings for all prompts, and enter this if block - # If the user in fact provides mixed input format, req_prompt_embeds will have some None's - # And `torch.stack` automatically raises an exception for us - prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + try: + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + except TypeError: + raise ValueError( + "If you provide `prompt_embeds` for at least one prompt, you have to provide `prompt_embeds` for" + " all prompts so the pipeline can stack them together." + ) + else: + prompt_embeds = None req_negative_prompt_embeds = [ p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts ] if any(p is not None for p in req_negative_prompt_embeds): - negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError + try: + negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError + except TypeError: + raise ValueError( + "If you provide `negative_prompt_embeds` for at least one prompt, " + "you have to provide `negative_prompt_embeds` for all prompts " + "so the pipeline can stack them together." + ) + else: + negative_prompt_embeds = None + + height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor + width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor + num_inference_steps = req.sampling_params.num_inference_steps or 50 + sigmas = req.sampling_params.sigmas + generator = req.sampling_params.generator + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 4.5 + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 + ) + latents = req.sampling_params.latents + enable_prompt_rewrite: bool = req.sampling_params.extra_args.get("enable_prompt_rewrite", True) + enable_cfg_renorm: bool = req.sampling_params.extra_args.get("enable_cfg_renorm", True) + cfg_renorm_min: float = req.sampling_params.extra_args.get("cfg_renorm_min", 0.0) + output_type: str = req.sampling_params.extra_args.get("output_type", "pil") + joint_attention_kwargs: dict[str, Any] | None = req.sampling_params.extra_args.get( + "joint_attention_kwargs", None + ) self.check_inputs( prompt, diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py index 3ba5e344883..4275fb97e36 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py @@ -6,7 +6,7 @@ import os import re from collections.abc import Iterable -from typing import Any, cast +from typing import cast import numpy as np import PIL.Image @@ -519,24 +519,7 @@ def check_inputs( f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) - def forward( - self, - req: OmniDiffusionRequest, - image: PIL.Image.Image | torch.Tensor | None = None, - prompt: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - guidance_scale: float = 3.5, - num_images_per_prompt: int | None = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.FloatTensor | None = None, - prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, - output_type: str | None = "pil", - return_dict: bool = True, - joint_attention_kwargs: dict[str, Any] | None = None, - ): + def forward(self, req: OmniDiffusionRequest): # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") # TODO: May be some data formatting operations on the API side. Hack for now. if len(req.prompts) > 1: @@ -548,21 +531,23 @@ def forward( prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") negative_prompt = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt") prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds") - negative_prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds") # type: ignore # Why it is list[torch.Tensor] in OmniTokenInputs or OmniEmbedsPrompt? Doesn't make sense + negative_prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds") - sigmas = req.sampling_params.sigmas or sigmas - guidance_scale = ( - req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale - ) - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 3.5 + num_inference_steps = req.sampling_params.num_inference_steps or 50 num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt is not None - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 ) - generator = req.sampling_params.generator or generator + generator = req.sampling_params.generator height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor + latents = req.sampling_params.latents + + output_type: str = req.sampling_params.extra_args.get("output_type", "pil") if prompt is not None: batch_size = 1 if isinstance(prompt, str) else len(prompt) @@ -572,11 +557,24 @@ def forward( if not isinstance(first_prompt, str) and "preprocessed_image" in ( additional_information := first_prompt.get("additional_information", {}) ): + # Using preprocessed image prompt_image = additional_information.get("prompt_image") image = additional_information.get("preprocessed_image") calculated_height = additional_information.get("calculated_height", height) calculated_width = additional_information.get("calculated_width", width) else: + # Using original image + if ( + raw_image := None + if isinstance(first_prompt, str) + else first_prompt.get("multi_modal_data", {}).get("image") + ) is None: + image = None + elif isinstance(raw_image, list): + image = [PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im) for im in raw_image] + else: + image = PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image) + image_size = image[0].size if isinstance(image, list) else image.size calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] * 1.0 / image_size[1]) diff --git a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py index 963f1c483b3..d128263eec4 100644 --- a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py +++ b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py @@ -18,7 +18,7 @@ import inspect import json import os -from collections.abc import Callable, Iterable +from collections.abc import Iterable from typing import Any import numpy as np @@ -524,88 +524,25 @@ def current_timestep(self): def interrupt(self): return self._interrupt - def forward( - self, - req: OmniDiffusionRequest, - prompt: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - guidance_scale: float = 5.0, - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - num_images_per_prompt: int | None = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.FloatTensor | None = None, - prompt_embeds: torch.FloatTensor | None = None, - negative_prompt_embeds: torch.FloatTensor | None = None, - output_type: str | None = "pil", - return_dict: bool = True, - joint_attention_kwargs: dict[str, Any] | None = None, - callback_on_step_end: Callable[[int, int, dict], None] | None = None, - callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 256, - ) -> DiffusionOutput: + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: r""" Function invoked when calling the pipeline for generation. Args: - prompt (`str` or `list[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - negative_prompt (`str` or `list[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - not greater than `1`). - guidance_scale (`float`, *optional*, defaults to 1.0): - True classifier-free guidance (guidance scale) is enabled when `guidance_scale` > 1 and - `negative_prompt` is provided. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. This is set to 1024 by default for the best results. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. This is set to 1024 by default for the best results. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - sigmas (`list[float]`, *optional*): - Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in - their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed - will be used. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - generator (`torch.Generator` or `list[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will be generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. - joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`list`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + req (`OmniDiffusionRequest`): + The request object containing the prompts and sampling parameters. + The `req.sampling_params.extra_args` can include the following keys: + - output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + - joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + - callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in + the `._callback_tensor_inputs` attribute of your pipeline class. Examples: @@ -616,24 +553,60 @@ def forward( """ # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") # TODO: May be some data formatting operations on the API side. Hack for now. - prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): negative_prompt = None elif req.prompts: negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + try: + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + except TypeError: + raise ValueError( + "If you provide `prompt_embeds` for at least one prompt, you have to provide `prompt_embeds` for" + " all prompts so the pipeline can stack them together." + ) + else: + prompt_embeds = None + + req_negative_prompt_embeds = [ + p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts + ] + if any(p is not None for p in req_negative_prompt_embeds): + try: + negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError + except TypeError: + raise ValueError( + "If you provide `negative_prompt_embeds` for at least one prompt, " + "you have to provide `negative_prompt_embeds` for all prompts " + "so the pipeline can stack them together." + ) + else: + negative_prompt_embeds = None + height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - sigmas = req.sampling_params.sigmas or sigmas - guidance_scale = ( - req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale - ) - generator = req.sampling_params.generator or generator + num_inference_steps = req.sampling_params.num_inference_steps or 50 + sigmas = req.sampling_params.sigmas + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 5.0 + generator = req.sampling_params.generator num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt > 0 - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 + ) + latents = req.sampling_params.latents + max_sequence_length = req.sampling_params.max_sequence_length or 256 + + output_type: str = req.sampling_params.extra_args.get("output_type", "pil") + joint_attention_kwargs: dict[str, Any] | None = req.sampling_params.extra_args.get( + "joint_attention_kwargs", None + ) + callback_on_step_end_tensor_inputs: list[str] = req.sampling_params.extra_args.get( + "callback_on_step_end_tensor_inputs", ["latents"] ) # Steps: diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index 02758ee19ab..096b1dc2630 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -534,51 +534,68 @@ def current_timestep(self): def interrupt(self): return self._interrupt - def forward( - self, - req: OmniDiffusionRequest, - prompt: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - true_cfg_scale: float = 4.0, - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - guidance_scale: float = 1.0, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.Tensor | None = None, - prompt_embeds: torch.Tensor | None = None, - prompt_embeds_mask: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds_mask: torch.Tensor | None = None, - output_type: str | None = "pil", - attention_kwargs: dict[str, Any] | None = None, - callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, - ) -> DiffusionOutput: + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") # TODO: May be some data formatting operations on the API side. Hack for now. - prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): negative_prompt = None elif req.prompts: negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + try: + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + except TypeError: + raise ValueError( + "If you provide `prompt_embeds` for at least one prompt, you have to provide `prompt_embeds` for" + " all prompts so the pipeline can stack them together." + ) + else: + prompt_embeds = None + + req_negative_prompt_embeds = [ + p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts + ] + if any(p is not None for p in req_negative_prompt_embeds): + try: + negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError + except TypeError: + raise ValueError( + "If you provide `negative_prompt_embeds` for at least one prompt, " + "you have to provide `negative_prompt_embeds` for all prompts " + "so the pipeline can stack them together." + ) + else: + negative_prompt_embeds = None + height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - sigmas = req.sampling_params.sigmas or sigmas - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length - generator = req.sampling_params.generator or generator - true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + num_inference_steps = req.sampling_params.num_inference_steps or 50 + sigmas = req.sampling_params.sigmas + max_sequence_length = req.sampling_params.max_sequence_length or 512 + generator = req.sampling_params.generator + true_cfg_scale = req.sampling_params.true_cfg_scale or 4.0 if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 1.0 num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt > 0 - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 ) + latents = req.sampling_params.latents + + prompt_embeds_mask: torch.Tensor | None = req.sampling_params.extra_args.get("prompt_embeds_mask", None) + negative_prompt_embeds_mask: torch.Tensor | None = req.sampling_params.extra_args.get( + "negative_prompt_embeds_mask", None + ) + output_type: str = req.sampling_params.extra_args.get("output_type", "pil") + attention_kwargs: dict[str, Any] | None = req.sampling_params.extra_args.get("attention_kwargs", None) + callback_on_step_end_tensor_inputs: list[str] = req.sampling_params.extra_args.get( + "callback_on_step_end_tensor_inputs", ["latents"] + ) + # 1. check inputs # 2. encode prompts # 3. prepare latents and timesteps diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index bff248a444b..dc3c58c3f05 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -597,30 +597,7 @@ def current_timestep(self): def interrupt(self): return self._interrupt - def forward( - self, - req: OmniDiffusionRequest, - prompt: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - image: PIL.Image.Image | torch.Tensor | None = None, - true_cfg_scale: float = 4.0, - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - guidance_scale: float = 1.0, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.Tensor | None = None, - prompt_embeds: torch.Tensor | None = None, - prompt_embeds_mask: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds_mask: torch.Tensor | None = None, - output_type: str | None = "pil", - attention_kwargs: dict[str, Any] | None = None, - callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, - ) -> DiffusionOutput: + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: """Forward pass for image editing.""" # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") # TODO: May be some data formatting operations on the API side. Hack for now. @@ -639,6 +616,8 @@ def forward( "Qwen official repository recommends to use whitespace string as negative_prompt. " "Note: some distilled variants may not be affected by this." ) + prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds") + negative_prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds") # Get preprocessed image from request (pre-processing is done in DiffusionEngine) if not isinstance(first_prompt, str) and "preprocessed_image" in ( @@ -652,10 +631,21 @@ def forward( width = req.sampling_params.width else: # fallback to run pre-processing in pipeline (debug only) + if ( + raw_image := None + if isinstance(first_prompt, str) + else first_prompt.get("multi_modal_data", {}).get("image") + ) is None: + image = None + elif isinstance(raw_image, list): + image = [PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im) for im in raw_image] + else: + image = PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image) + image_size = image[0].size if isinstance(image, list) else image.size calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) - height = height or calculated_height - width = width or calculated_width + height = req.sampling_params.height or calculated_height + width = req.sampling_params.width or calculated_width multiple_of = self.vae_scale_factor * 2 width = width // multiple_of * multiple_of @@ -667,17 +657,28 @@ def forward( image = self.image_processor.preprocess(image, calculated_height, calculated_width) image = image.unsqueeze(2) - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - sigmas = req.sampling_params.sigmas or sigmas - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length - generator = req.sampling_params.generator or generator - true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + num_inference_steps = req.sampling_params.num_inference_steps or 50 + sigmas = req.sampling_params.sigmas + max_sequence_length = req.sampling_params.max_sequence_length or 512 + generator = req.sampling_params.generator + true_cfg_scale = req.sampling_params.true_cfg_scale or 4.0 if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 1.0 num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt > 0 - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 + ) + latents = req.sampling_params.latents + + prompt_embeds_mask: torch.Tensor | None = req.sampling_params.extra_args.get("prompt_embeds_mask", None) + negative_prompt_embeds_mask: torch.Tensor | None = req.sampling_params.extra_args.get( + "negative_prompt_embeds_mask", None + ) + output_type: str = req.sampling_params.extra_args.get("output_type", "pil") + attention_kwargs: dict[str, Any] | None = req.sampling_params.extra_args.get("attention_kwargs", None) + callback_on_step_end_tensor_inputs: list[str] = req.sampling_params.extra_args.get( + "callback_on_step_end_tensor_inputs", ["latents"] ) # 1. check inputs diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 2fdc7003efd..06c6049233e 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -528,30 +528,7 @@ def current_timestep(self): def interrupt(self): return self._interrupt - def forward( - self, - req: OmniDiffusionRequest, - prompt: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - image: PIL.Image.Image | list[PIL.Image.Image] | torch.Tensor | None = None, - true_cfg_scale: float = 4.0, - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - guidance_scale: float = 1.0, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.Tensor | None = None, - prompt_embeds: torch.Tensor | None = None, - prompt_embeds_mask: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds_mask: torch.Tensor | None = None, - output_type: str | None = "pil", - attention_kwargs: dict[str, Any] | None = None, - callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, - ) -> DiffusionOutput: + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: """Forward pass for image editing with support for multiple images.""" # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") # TODO: May be some data formatting operations on the API side. Hack for now. @@ -570,6 +547,8 @@ def forward( "Qwen official repository recommends to use whitespace string as negative_prompt. " "Note: some distilled variants may not be affected by this." ) + prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds") + negative_prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds") # Get preprocessed images from request (pre-processing is done in DiffusionEngine) if ( @@ -587,16 +566,21 @@ def forward( width = req.sampling_params.width else: # fallback to run pre-processing in pipeline (debug only) - if image is None: + if ( + raw_image := None + if isinstance(first_prompt, str) + else first_prompt.get("multi_modal_data", {}).get("image") + ) is None: raise ValueError("Image is required for QwenImageEditPlusPipeline") - - if not isinstance(image, list): - image = [image] + elif isinstance(raw_image, list): + image = [PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im) for im in raw_image] + else: + image = [PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image)] image_size = image[0].size calculated_width, calculated_height = calculate_dimensions(VAE_IMAGE_SIZE, image_size[0] / image_size[1]) - height = height or calculated_height - width = width or calculated_width + height = req.sampling_params.height or calculated_height + width = req.sampling_params.width or calculated_width multiple_of = self.vae_scale_factor * 2 width = width // multiple_of * multiple_of @@ -618,17 +602,28 @@ def forward( condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - sigmas = req.sampling_params.sigmas or sigmas - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length - generator = req.sampling_params.generator or generator - true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + num_inference_steps = req.sampling_params.num_inference_steps or 50 + sigmas = req.sampling_params.sigmas + max_sequence_length = req.sampling_params.max_sequence_length or 512 + generator = req.sampling_params.generator + true_cfg_scale = req.sampling_params.true_cfg_scale or 4.0 if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 1.0 num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt > 0 - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 + ) + latents = req.sampling_params.latents + + prompt_embeds_mask: torch.Tensor | None = req.sampling_params.extra_args.get("prompt_embeds_mask", None) + negative_prompt_embeds_mask: torch.Tensor | None = req.sampling_params.extra_args.get( + "negative_prompt_embeds_mask", None + ) + output_type: str = req.sampling_params.extra_args.get("output_type", "pil") + attention_kwargs: dict[str, Any] | None = req.sampling_params.extra_args.get("attention_kwargs", None) + callback_on_step_end_tensor_inputs: list[str] = req.sampling_params.extra_args.get( + "callback_on_step_end_tensor_inputs", ["latents"] ) # 1. check inputs diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index dbe0bfe4f85..9ef6708a28d 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -575,31 +575,7 @@ def current_timestep(self): def interrupt(self): return self._interrupt - def forward( - self, - req: OmniDiffusionRequest, - image: PIL.Image.Image | torch.Tensor | None = None, - prompt: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - true_cfg_scale: float = 4.0, - layers: int | None = 4, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - guidance_scale: float | None = None, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.Tensor | None = None, - prompt_embeds: torch.Tensor | None = None, - prompt_embeds_mask: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds_mask: torch.Tensor | None = None, - output_type: str | None = "pil", - attention_kwargs: dict[str, Any] | None = None, - max_sequence_length: int = 512, - resolution: int = 640, - cfg_normalize: bool = False, - use_en_prompt: bool = False, - ) -> DiffusionOutput: + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: """Forward pass for image layered.""" # 1. Get preprocessed image from request (pre-processing is done in DiffusionEngine) @@ -614,27 +590,33 @@ def forward( first_prompt = req.prompts[0] prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") negative_prompt = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt") - - layers = req.sampling_params.layers if req.sampling_params.layers is not None else layers - resolution = req.sampling_params.resolution if req.sampling_params.resolution is not None else resolution - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length - cfg_normalize = ( - req.sampling_params.cfg_normalize if req.sampling_params.cfg_normalize is not None else cfg_normalize - ) - use_en_prompt = ( - req.sampling_params.use_en_prompt if req.sampling_params.use_en_prompt is not None else use_en_prompt - ) - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - sigmas = req.sampling_params.sigmas or sigmas - generator = req.sampling_params.generator or generator - true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds") + negative_prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds") + + layers = req.sampling_params.layers or 4 + resolution = req.sampling_params.resolution or 640 + max_sequence_length = req.sampling_params.max_sequence_length or 512 + cfg_normalize = req.sampling_params.cfg_normalize + use_en_prompt = req.sampling_params.use_en_prompt + num_inference_steps = req.sampling_params.num_inference_steps or 50 + sigmas = req.sampling_params.sigmas + generator = req.sampling_params.generator + true_cfg_scale = req.sampling_params.true_cfg_scale or 4.0 if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 1.0 num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt > 0 - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 + ) + latents = req.sampling_params.latents + + prompt_embeds_mask: torch.Tensor | None = req.sampling_params.extra_args.get("prompt_embeds_mask", None) + negative_prompt_embeds_mask: torch.Tensor | None = req.sampling_params.extra_args.get( + "negative_prompt_embeds_mask", None ) + output_type: str = req.sampling_params.extra_args.get("output_type", "pil") + attention_kwargs: dict[str, Any] | None = req.sampling_params.extra_args.get("attention_kwargs", None) if not isinstance(first_prompt, str) and "preprocessed_image" in ( additional_information := first_prompt.get("additional_information", {}) @@ -648,6 +630,17 @@ def forward( width = req.sampling_params.width else: # fallback to run pre-processing in pipeline (debug only) + if ( + raw_image := None + if isinstance(first_prompt, str) + else first_prompt.get("multi_modal_data", {}).get("image") + ) is None: + image = None + elif isinstance(raw_image, list): + image = [PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im) for im in raw_image] + else: + image = PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image) + image_size = image[0].size if isinstance(image, list) else image.size assert resolution in [640, 1024], f"resolution must be either 640 or 1024, but got {resolution}" calculated_width, calculated_height = calculate_dimensions( diff --git a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py index 9c9722c368a..a744e580497 100644 --- a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py +++ b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py @@ -566,46 +566,59 @@ def diffuse( return latents - def forward( - self, - req: OmniDiffusionRequest, - prompt: str | list[str] = "", - prompt_2: str | list[str] = "", - prompt_3: str | list[str] = "", - negative_prompt: str | list[str] = "", - negative_prompt_2: str | list[str] = "", - negative_prompt_3: str | list[str] = "", - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 28, - sigmas: list[float] | None = None, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.Tensor | None = None, - prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, - pooled_prompt_embeds: torch.Tensor | None = None, - negative_pooled_prompt_embeds: torch.Tensor | None = None, - max_sequence_length: int = 256, - ) -> DiffusionOutput: + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") # TODO: May be some data formatting operations on the API side. Hack for now. - prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt - negative_prompt = [ - "" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts - ] or negative_prompt + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] + negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + + req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + try: + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + except TypeError: + raise ValueError( + "If you provide `prompt_embeds` for at least one prompt, you have to provide `prompt_embeds` for" + " all prompts so the pipeline can stack them together." + ) + else: + prompt_embeds = None + + req_negative_prompt_embeds = [ + p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts + ] + if any(p is not None for p in req_negative_prompt_embeds): + try: + negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError + except TypeError: + raise ValueError( + "If you provide `negative_prompt_embeds` for at least one prompt, " + "you have to provide `negative_prompt_embeds` for all prompts " + "so the pipeline can stack them together." + ) + else: + negative_prompt_embeds = None height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor - sigmas = req.sampling_params.sigmas or sigmas - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - generator = req.sampling_params.generator or generator + sigmas = req.sampling_params.sigmas + max_sequence_length = req.sampling_params.max_sequence_length or 256 + num_inference_steps = req.sampling_params.num_inference_steps or 28 + generator = req.sampling_params.generator num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt > 0 - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 + ) + latents = req.sampling_params.latents + + prompt_2: str | list[str] = req.sampling_params.extra_args.get("prompt_2", "") + prompt_3: str | list[str] = req.sampling_params.extra_args.get("prompt_3", "") + negative_prompt_2: str | list[str] = req.sampling_params.extra_args.get("negative_prompt_2", "") + negative_prompt_3: str | list[str] = req.sampling_params.extra_args.get("negative_prompt_3", "") + pooled_prompt_embeds: torch.Tensor | None = req.sampling_params.extra_args.get("pooled_prompt_embeds", None) + negative_pooled_prompt_embeds: torch.Tensor | None = req.sampling_params.extra_args.get( + "negative_pooled_prompt_embeds", None ) + # 1. check inputs # 2. encode prompts # 3. prepare latents and timesteps @@ -626,7 +639,10 @@ def forward( max_sequence_length=max_sequence_length, ) - self._guidance_scale = req.sampling_params.guidance_scale + if req.sampling_params.guidance_scale_provided: + self._guidance_scale = req.sampling_params.guidance_scale + else: + self._guidance_scale = 7.0 self._current_timestep = None self._interrupt = False diff --git a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py index 22dfc06c5a8..599f3772979 100644 --- a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py +++ b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py @@ -349,39 +349,21 @@ def prepare_latents( latents = latents * self.scheduler.init_noise_sigma return latents - def forward( - self, - req: OmniDiffusionRequest, - prompt: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - audio_end_in_s: float | None = None, - audio_start_in_s: float = 0.0, - num_inference_steps: int = 100, - guidance_scale: float = 7.0, - num_waveforms_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.Tensor | None = None, - prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, - output_type: str = "np", - ) -> DiffusionOutput: + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: """ Generate audio from text prompt. Args: - req: OmniDiffusionRequest containing generation parameters - prompt: Text prompt for audio generation - negative_prompt: Negative prompt for CFG - audio_end_in_s: Audio end time in seconds (max ~47s for stable-audio-open-1.0) - audio_start_in_s: Audio start time in seconds - num_inference_steps: Number of denoising steps - guidance_scale: CFG scale - num_waveforms_per_prompt: Number of audio outputs per prompt - generator: Random generator for reproducibility - latents: Pre-generated latents - prompt_embeds: Pre-computed prompt embeddings - negative_prompt_embeds: Pre-computed negative prompt embeddings - output_type: Output format ("np", "pt", or "latent") + req: OmniDiffusionRequest containing generation parameters. + The `req.sampling_params.extra_args` can include the following keys: + - audio_start_in_s (`float`, *optional*, defaults to 0.0): + Start time of the audio in seconds. + - audio_end_in_s (`float`, *optional*): + End time of the audio in seconds. + - num_waveforms_per_prompt (`int`, *optional*, defaults to 1): + Number of audio outputs per prompt. + - output_type (`str`, *optional*, defaults to "np"): + Output format ("np", "pt", or "latent"). Returns: DiffusionOutput containing generated audio @@ -389,24 +371,55 @@ def forward( # Extract from request # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") # TODO: May be some data formatting operations on the API side. Hack for now. - prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): negative_prompt = None elif req.prompts: negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + try: + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + except TypeError: + raise ValueError( + "If you provide `prompt_embeds` for at least one prompt, you have to provide `prompt_embeds` for" + " all prompts so the pipeline can stack them together." + ) + else: + prompt_embeds = None + + req_negative_prompt_embeds = [ + p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts + ] + if any(p is not None for p in req_negative_prompt_embeds): + try: + negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError + except TypeError: + raise ValueError( + "If you provide `negative_prompt_embeds` for at least one prompt, " + "you have to provide `negative_prompt_embeds` for all prompts " + "so the pipeline can stack them together." + ) + else: + negative_prompt_embeds = None + + num_inference_steps = req.sampling_params.num_inference_steps or 100 if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 7.0 - if generator is None: - generator = req.sampling_params.generator + generator = req.sampling_params.generator if generator is None and req.sampling_params.seed is not None: generator = torch.Generator(device=self.device).manual_seed(req.sampling_params.seed) + latents = req.sampling_params.latents # Get audio duration from request extra params or defaults - audio_start_in_s = req.sampling_params.extra_args.get("audio_start_in_s", audio_start_in_s) - audio_end_in_s = req.sampling_params.extra_args.get("audio_end_in_s", audio_end_in_s) + audio_start_in_s: float = req.sampling_params.extra_args.get("audio_start_in_s", 0.0) + audio_end_in_s: float | None = req.sampling_params.extra_args.get("audio_end_in_s", None) + num_waveforms_per_prompt: int = req.sampling_params.extra_args.get("num_waveforms_per_prompt", 1) + output_type: str = req.sampling_params.extra_args.get("output_type", "np") # Calculate audio length downsample_ratio = self.vae.hop_length diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 9a0037a6a1b..b373bd625a6 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -327,23 +327,7 @@ def num_timesteps(self): def current_timestep(self): return self._current_timestep - def forward( - self, - req: OmniDiffusionRequest, - prompt: str | None = None, - negative_prompt: str | None = None, - height: int = 480, - width: int = 832, - num_inference_steps: int = 40, - guidance_scale: float | tuple[float, float] = 4.0, - frame_num: int = 81, - output_type: str | None = "np", - generator: torch.Generator | list[torch.Generator] | None = None, - prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, - attention_kwargs: dict | None = None, - **kwargs, - ) -> DiffusionOutput: + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: # Get parameters from request or arguments if len(req.prompts) > 1: raise ValueError( @@ -351,14 +335,20 @@ def forward( """Please pass in a single prompt object or string, or a single-item list.""", ) if len(req.prompts) == 1: # If req.prompt is empty, default to prompt & neg_prompt in param list - prompt = req.prompts[0] if isinstance(req.prompts[0], str) else req.prompts[0].get("prompt") - negative_prompt = None if isinstance(req.prompts[0], str) else req.prompts[0].get("negative_prompt") + first_prompt = req.prompts[0] + prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") + negative_prompt = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt") + prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds") + negative_prompt_embeds = ( + None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds") + ) + if prompt is None and prompt_embeds is None: raise ValueError("Prompt or prompt_embeds is required for Wan2.2 generation.") - height = req.sampling_params.height or height - width = req.sampling_params.width or width - num_frames = req.sampling_params.num_frames if req.sampling_params.num_frames else frame_num + height = req.sampling_params.height or 480 + width = req.sampling_params.width or 832 + num_frames = req.sampling_params.num_frames or 81 # Ensure dimensions are compatible with VAE and patch size # For expand_timesteps mode, we need latent dims to be even (divisible by patch_size) @@ -366,11 +356,13 @@ def forward( mod_value = self.vae_scale_factor_spatial * patch_size[1] # 16*2=32 for TI2V, 8*2=16 for I2V height = (height // mod_value) * mod_value width = (width // mod_value) * mod_value - num_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_steps = req.sampling_params.num_inference_steps or 40 # Respect per-request guidance_scale when explicitly provided. if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 4.0 guidance_low = guidance_scale if isinstance(guidance_scale, (int, float)) else guidance_scale[0] guidance_high = ( @@ -383,6 +375,9 @@ def forward( ) ) + output_type: str = req.sampling_params.extra_args.get("output_type", "np") + attention_kwargs: dict | None = req.sampling_params.extra_args.get("attention_kwargs", None) + # record guidance for properties self._guidance_scale = guidance_low self._guidance_scale_2 = guidance_high @@ -421,8 +416,7 @@ def forward( dtype = self.text_encoder.dtype # Seed / generator - if generator is None: - generator = req.sampling_params.generator + generator = req.sampling_params.generator if generator is None and req.sampling_params.seed is not None: generator = torch.Generator(device=device).manual_seed(req.sampling_params.seed) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index fe15a24f587..25cbb5b0df0 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -281,26 +281,7 @@ def encode_image( image_embeds = self.image_encoder(pixel_values, output_hidden_states=True) return image_embeds.hidden_states[-2] - def forward( - self, - req: OmniDiffusionRequest, - prompt: str | None = None, - negative_prompt: str | None = None, - image: PIL.Image.Image | torch.Tensor | None = None, - height: int = 480, - width: int = 832, - num_inference_steps: int = 40, - guidance_scale: float | tuple[float, float] = 5.0, - frame_num: int = 81, - output_type: str | None = "np", - generator: torch.Generator | list[torch.Generator] | None = None, - prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, - image_embeds: torch.Tensor | None = None, - last_image: PIL.Image.Image | torch.Tensor | None = None, - attention_kwargs: dict | None = None, - **kwargs, - ) -> DiffusionOutput: + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: # Get parameters from request or arguments if len(req.prompts) > 1: raise ValueError( @@ -308,39 +289,43 @@ def forward( """Please pass in a single prompt object or string, or a single-item list.""", ) if len(req.prompts) == 1: # If req.prompt is empty, default to prompt & neg_prompt in param list - prompt = req.prompts[0] if isinstance(req.prompts[0], str) else req.prompts[0].get("prompt") - negative_prompt = None if isinstance(req.prompts[0], str) else req.prompts[0].get("negative_prompt") + first_prompt = req.prompts[0] + prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") + negative_prompt = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt") + prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds") + negative_prompt_embeds = ( + None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds") + ) if prompt is None and prompt_embeds is None: raise ValueError("Prompt or prompt_embeds is required for Wan2.2 generation.") # Get image from request - if image is None: - multi_modal_data = ( - req.prompts[0].get("multi_modal_data", {}) if not isinstance(req.prompts[0], str) else None - ) - raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None - if raw_image is None: - raise ValueError("Image is required for I2V generation.") - if isinstance(raw_image, list): - if len(raw_image) > 1: - logger.warning( - """Received a list of image. Only a single image is supported by this model.""" - """Taking only the first image for now.""" - ) - raw_image = raw_image[0] - if isinstance(raw_image, str): - image = PIL.Image.open(raw_image) - else: - image = cast(PIL.Image.Image | torch.Tensor, raw_image) + multi_modal_data = req.prompts[0].get("multi_modal_data", {}) if not isinstance(req.prompts[0], str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if raw_image is None: + raise ValueError("Image is required for I2V generation.") + if isinstance(raw_image, list): + if len(raw_image) > 1: + logger.warning( + """Received a list of image. Only a single image is supported by this model.""" + """Taking only the first image for now.""" + ) + raw_image = raw_image[0] + if isinstance(raw_image, str): + image = PIL.Image.open(raw_image) + else: + image = cast(PIL.Image.Image | torch.Tensor, raw_image) - height = req.sampling_params.height or height - width = req.sampling_params.width or width - num_frames = req.sampling_params.num_frames or frame_num - num_steps = req.sampling_params.num_inference_steps or num_inference_steps + height = req.sampling_params.height or 480 + width = req.sampling_params.width or 832 + num_frames = req.sampling_params.num_frames or 81 + num_steps = req.sampling_params.num_inference_steps or 40 # Respect per-request guidance_scale when explicitly provided. if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 5.0 # Handle guidance scales guidance_low = guidance_scale if isinstance(guidance_scale, (int, float)) else guidance_scale[0] @@ -354,6 +339,11 @@ def forward( ) ) + output_type: str = req.sampling_params.extra_args.get("output_type", "np") + image_embeds: torch.Tensor | None = req.sampling_params.extra_args.get("image_embeds", None) + last_image: PIL.Image.Image | torch.Tensor | None = req.sampling_params.extra_args.get("last_image", None) + attention_kwargs: dict | None = req.sampling_params.extra_args.get("attention_kwargs", None) + self._guidance_scale = guidance_low self._guidance_scale_2 = guidance_high @@ -385,8 +375,7 @@ def forward( dtype = self.transformer.dtype # Generator setup - if generator is None: - generator = req.sampling_params.generator + generator = req.sampling_params.generator if generator is None and req.sampling_params.seed is not None: generator = torch.Generator(device=device).manual_seed(req.sampling_params.seed) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index f116834cf28..0c15654a6dd 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -218,24 +218,7 @@ def num_timesteps(self): def current_timestep(self): return self._current_timestep - def forward( - self, - req: OmniDiffusionRequest, - prompt: str | None = None, - negative_prompt: str | None = None, - image: PIL.Image.Image | torch.Tensor | None = None, - height: int = 704, - width: int = 1280, - num_inference_steps: int = 40, - guidance_scale: float = 5.0, - frame_num: int = 81, - output_type: str | None = "np", - generator: torch.Generator | list[torch.Generator] | None = None, - prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, - attention_kwargs: dict | None = None, - **kwargs, - ) -> DiffusionOutput: + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: # Get parameters from request or arguments if len(req.prompts) > 1: raise ValueError( @@ -243,40 +226,47 @@ def forward( """Please pass in a single prompt object or string, or a single-item list.""", ) if len(req.prompts) == 1: # If req.prompt is empty, default to prompt & neg_prompt in param list - prompt = req.prompts[0] if isinstance(req.prompts[0], str) else req.prompts[0].get("prompt") - negative_prompt = None if isinstance(req.prompts[0], str) else req.prompts[0].get("negative_prompt") + first_prompt = req.prompts[0] + prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") + negative_prompt = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt") + prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds") + negative_prompt_embeds = ( + None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds") + ) if prompt is None and prompt_embeds is None: raise ValueError("Prompt or prompt_embeds is required for Wan2.2 generation.") # Get image from request (optional for TI2V) - if image is None: - multi_modal_data = ( - req.prompts[0].get("multi_modal_data", {}) if not isinstance(req.prompts[0], str) else None - ) - raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None - if isinstance(raw_image, list): - if len(raw_image) > 1: - logger.warning( - """Received a list of image. Only a single image is supported by this model.""" - """Taking only the first image for now.""" - ) - raw_image = raw_image[0] - if raw_image is None: - image = None - elif isinstance(raw_image, str): - image = PIL.Image.open(raw_image) - else: - image = cast(PIL.Image.Image | torch.Tensor, raw_image) + multi_modal_data = req.prompts[0].get("multi_modal_data", {}) if not isinstance(req.prompts[0], str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(raw_image, list): + if len(raw_image) > 1: + logger.warning( + """Received a list of image. Only a single image is supported by this model.""" + """Taking only the first image for now.""" + ) + raw_image = raw_image[0] + if raw_image is None: + image = None + elif isinstance(raw_image, str): + image = PIL.Image.open(raw_image) + else: + image = cast(PIL.Image.Image | torch.Tensor, raw_image) # Default dimensions for TI2V-5B (720P) - height = req.sampling_params.height or height - width = req.sampling_params.width or width - num_frames = req.sampling_params.num_frames if req.sampling_params.num_frames else frame_num - num_steps = req.sampling_params.num_inference_steps or num_inference_steps + height = req.sampling_params.height or 704 + width = req.sampling_params.width or 1280 + num_frames = req.sampling_params.num_frames or 81 + num_steps = req.sampling_params.num_inference_steps or 40 # Respect per-request guidance_scale when explicitly provided. if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 5.0 + + output_type: str = req.sampling_params.extra_args.get("output_type", "np") + attention_kwargs: dict | None = req.sampling_params.extra_args.get("attention_kwargs", None) self._guidance_scale = guidance_scale @@ -300,8 +290,7 @@ def forward( dtype = self.transformer.dtype # Generator setup - if generator is None: - generator = req.sampling_params.generator + generator = req.sampling_params.generator if generator is None and req.sampling_params.seed is not None: generator = torch.Generator(device=device).manual_seed(req.sampling_params.seed) diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index b16d24bf7b2..9409ae16aca 100644 --- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -315,99 +315,35 @@ def num_timesteps(self): def interrupt(self): return self._interrupt - def forward( - self, - req: OmniDiffusionRequest, - prompt: str | list[str] | None = None, - height: int = 1024, - width: int = 1024, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - guidance_scale: float = 5.0, - cfg_normalization: bool = False, - cfg_truncation: float = 1.0, - negative_prompt: str | list[str] | None = None, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.FloatTensor | None = None, - prompt_embeds: list[torch.FloatTensor] | None = None, - negative_prompt_embeds: list[torch.FloatTensor] | None = None, - output_type: str | None = "pil", - return_dict: bool = True, - joint_attention_kwargs: dict[str, Any] | None = None, - callback_on_step_end: Callable[[int, int, dict], None] | None = None, - callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, - ) -> DiffusionOutput: + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: r""" Function invoked when calling the pipeline for generation. Args: - prompt (`str` or `list[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - height (`int`, *optional*, defaults to 1024): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 1024): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - sigmas (`list[float]`, *optional*): - Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in - their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed - will be used. - guidance_scale (`float`, *optional*, defaults to 5.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - cfg_normalization (`bool`, *optional*, defaults to False): - Whether to apply configuration normalization. - cfg_truncation (`float`, *optional*, defaults to 1.0): - The truncation value for configuration. - negative_prompt (`str` or `list[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - generator (`torch.Generator` or `list[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will be generated by sampling using the supplied random `generator`. - prompt_embeds (`list[torch.FloatTensor]`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`list[torch.FloatTensor]`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain - tuple. - joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`list`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int`, *optional*, defaults to 512): - Maximum sequence length to use with the `prompt`. + req (`OmniDiffusionRequest`): + The request object containing the prompts and sampling parameters. + The `req.sampling_params.extra_args` can include the following keys: + - cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + - cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + - output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + - joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + - callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is + called with the following arguments: + `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. + `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + - callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in + the `._callback_tensor_inputs` attribute of your pipeline class. Examples: @@ -418,25 +354,68 @@ def forward( """ # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") # TODO: May be some data formatting operations on the API side. Hack for now. - prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] + + # For negative prompt, make it None if ALL are None---making it falsy and skipping CFG + # If only some of them are not None, only set those to empty strings---because we cannot skip CFG anyway. if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): - negative_prompt = None - elif req.prompts: + negative_prompt: list[str] | None = None + else: negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] - height = req.sampling_params.height or height - width = req.sampling_params.width or width - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + try: + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + except TypeError: + raise ValueError( + "If you provide `prompt_embeds` for at least one prompt, you have to provide `prompt_embeds` for" + " all prompts so the pipeline can stack them together." + ) + else: + prompt_embeds = None + + req_negative_prompt_embeds = [ + p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts + ] + if any(p is not None for p in req_negative_prompt_embeds): + try: + negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError + except TypeError: + raise ValueError( + "If you provide `negative_prompt_embeds` for at least one prompt, " + "you have to provide `negative_prompt_embeds` for all prompts " + "so the pipeline can stack them together." + ) + else: + negative_prompt_embeds = None + + height = req.sampling_params.height or 1024 + width = req.sampling_params.width or 1024 + num_inference_steps = req.sampling_params.num_inference_steps or 50 generator = req.sampling_params.generator - sigmas = req.sampling_params.sigmas or sigmas - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length - guidance_scale = ( - req.sampling_params.guidance_scale if req.sampling_params.guidance_rescale is not None else guidance_scale - ) + sigmas = req.sampling_params.sigmas + max_sequence_length = req.sampling_params.max_sequence_length or 512 + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 5.0 num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt > 0 - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 + ) + latents = req.sampling_params.latents + + cfg_normalization: bool = req.sampling_params.extra_args.get("cfg_normalization", False) + cfg_truncation: float = req.sampling_params.extra_args.get("cfg_truncation", 1.0) + output_type: str = req.sampling_params.extra_args.get("output_type", "pil") + joint_attention_kwargs: dict[str, Any] | None = req.sampling_params.extra_args.get( + "joint_attention_kwargs", None + ) + callback_on_step_end: Callable[[int, int, dict], None] | None = req.sampling_params.extra_args.get( + "callback_on_step_end", None + ) + callback_on_step_end_tensor_inputs: list[str] = req.sampling_params.extra_args.get( + "callback_on_step_end_tensor_inputs", ["latents"] ) vae_scale = self.vae_scale_factor * 2