Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions docs/contributing/model/adding_diffusion_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,27 @@ See some parameters in `OmniDiffusionSamplingParams` as follows:

**Extract parameters from request:**

The `OmniDiffusionRequest` object primarily contains two parts.

1. **`prompt`**: a list of pure-string or multimodal prompt. It matches the [data structure of vLLM](https://docs.vllm.ai/en/stable/features/multimodal_inputs/#image-inputs). Each prompt in the list can be a string or a TypedDict. The dict version allows image input at `["multi_modal_data"]["images"]` and negative prompt at `["negative_prompt"]`.
- If your model requires a preprocess function, then the intermediate preprocessed values can be stored at the `["additional_information"]` field of a TypedDict prompt.
- If your model does not support batched input request, you can check the length of `req.prompts` and complain about the input to the user. In this case, the user is encouraged to request the prompts one-by-one.
- For example, an image editing model may expect the `prompt` to be something like this:
```python
[
{
"prompt": "turn this cat to a dog",
"multi_modal_data": {"image": input_image}
},
]
```

2. **`sampling_params`**: a collection of common sampling parameters. Check the definition of [`OmniDiffusionSamplingParams`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/inputs/data/#vllm_omni.inputs.data.OmniDiffusionSamplingParams) dataclass for their default values.
- If your model requires a less-common sampling parameter, you can read it from the `["extra_args"]` field of the dataclass. To ensure user experience, you may want to document the list of extra args that your pipeline honors.
- If you believe a sampling parameter is common enough to be included in the `OmniDiffusionSamplingParams` dataclass, feel free to open an issue or clarify it in your PR that adds your model.
Copy link
Copy Markdown
Collaborator

@SamitHuang SamitHuang Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes thing complicated (e.g. no definition on what is a common parameter for diffusion).

BTW, is it really necessary to distinguish parameters from common and less-common (extra-args, like image_embeds)? why not parse them all via sampling_params?

Copy link
Copy Markdown
Contributor Author

@fhfuih fhfuih Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to our internal discussion, there indeed isn't a clear boundary between common and less common parameters. After my further investigation, I think maybe

  • Those parameters that have been inside extra_args stays there so that things don't break: cfg_text_scale, cfg_img_scale, audio_start_in_s. Note that many of them do make sense as sampling params
  • Those parameters that I newly move to extra_args in this PR, categorize them in four scenarios below:

  • Move to extra_args only if (1) they are related to the inference runtime or the encoding/embedding stage, but not the input data itself
    • cfg_normalization, cfg_truncation (only used in z-image)
    • enable_cfg_renorm, cfg_renorm_min, enable_prompt_rewrite (only in longcat)
    • num_waveforms_per_prompt (only in stable audio)
    • text_encoder_out_layers (only in flux klein)
    • Note: vLLM also shows example usage of extra_args for tokenizing and embedding routines
  • Those parameters that are used in many models and are related to runtime, promote them as an OmniDiffusionRequest property:
    • output_type (used in stable audio, several Alibaba models, stable audio, etc.) Make it default to None in OmniDiffusionRequest. Then in pipeline implementation, read it with model-specific fallback values (np or pil or others) that are consistent with their current implementation.
    • joint_attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs (in flux, longcat, z image, ovis), Make it default to empty values, and they are also always empty in the current implementations
  • Those parameters that are intended to be part of the input, it depends:
    • image_embeds, last_image (only used in wan i2v)---we can move them to req.prompt["multi_modal_data"]["image_embeds"] and req.prompt["multi_modal_data"]["last_image"] for clarity, but req.prompt is a list of single prompts, which results in a list of batch-1 embeds. For efficient data IO, we can also keep it in extra_args.
    • prompt_2, prompt_3, negative_prompt_2, negative_prompt_3 (used by both SD3 and flux)---promote them as OmniTextPrompt fields
    • pooled_prompt_embeds, negative_pooled_prompt_embeds (used by both SD3 and flux)---promote them as OmniTextPrompt fields, but same as the argument above about image_embeds.
    • prompt_embeds_mask,negative_prompt_embeds_mask (only in qwen image series)---same as above
    • Why keeping embeddings in extra_args also makes some sense:
      1. For image_embeds, some models always calculate it within the pipeline, but WAN allows user-input override. This is an unusual and not-unified pattern.
      2. For (neg)_prompt_embeds_mask and (neg_)pooled_prompt_embeds, embeddings are technically not expected in OmniTextPrompt. Plus these input data fields are not widely used by many models.

What do you think?


Below is an example way to extract the prompt strings and sampling parameters from the `OmniDiffusionRequest`.

```python
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.diffusion.data import DiffusionOutput
Expand Down Expand Up @@ -366,14 +387,6 @@ def forward(
# ... rest of generation logic
```

For an image editing model, an example `OmniDiffusionRequest` is like:
```python
{
"prompt": "turn this cat to a dog",
"multi_modal_data": {"image": input_image}
},
```

**Wrap output:**

```diff
Expand Down
9 changes: 1 addition & 8 deletions docs/design/feature/cfg_parallel.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,7 @@ Call `self.diffuse` in your pipeline's forward function:
```python
import torch.nn as nn
class YourModelPipeline(nn.Module, CFGParallelMixin):
def forward(
self,
prompt: str,
negative_prompt: str | None = None,
guidance_scale: float = 3.5,
num_inference_steps: int = 50,
**kwargs,
):
def forward(self, req: OmniDiffusionRequest):
# Encode prompts, Initialize latents, Get timesteps
...
# Run diffusion loop (calls the mixin's diffuse method)
Expand Down
6 changes: 3 additions & 3 deletions docs/features/custom_pipeline.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ class CustomPipeline(QwenImageEditPipeline):
def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
super().__init__(od_config=od_config, prefix=prefix)

def forward(self, req, prompt=None, negative_prompt=None, **kwargs):
def forward(self, req):
# Call parent's forward to get normal output
output = super().forward(req=req, prompt=prompt, negative_prompt=negative_prompt, **kwargs)
output = super().forward(req=req)

# Add custom trajectory data
actual_num_steps = req.sampling_params.num_inference_steps or kwargs.get('num_inference_steps', 50)
actual_num_steps = req.sampling_params.num_inference_steps or 50
output.trajectory_timesteps = torch.linspace(1000, 0, actual_num_steps, dtype=torch.float32)
output.trajectory_latents = torch.randn(actual_num_steps, 1, 16, 64, 64, dtype=torch.float32)

Expand Down
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@knlnguyen1802 I have updated this file and custom_pipeline.md to match the usage of forward function after #797 and the cleanup refactoring in this PR. See if that looks good to you

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fhfuih It is fine for me if the example can run successfully

Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import logging
from typing import Any

import PIL.Image
import torch

from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
Expand All @@ -18,58 +16,13 @@ class CustomPipeline(QwenImageEditPipeline):
def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
super().__init__(od_config=od_config, prefix=prefix)

def forward(
self,
req: OmniDiffusionRequest,
prompt: str | list[str] | None = None,
negative_prompt: str | list[str] | None = None,
image: PIL.Image.Image | torch.Tensor | None = None,
true_cfg_scale: float = 4.0,
height: int | None = None,
width: int | None = None,
num_inference_steps: int = 50,
sigmas: list[float] | None = None,
guidance_scale: float = 1.0,
num_images_per_prompt: int = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
prompt_embeds: torch.Tensor | None = None,
prompt_embeds_mask: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds_mask: torch.Tensor | None = None,
output_type: str | None = "pil",
attention_kwargs: dict[str, Any] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
max_sequence_length: int = 512,
) -> DiffusionOutput:
def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
"""Forward pass for image editing with dummy trajectory data."""
# Call parent's forward to get the normal output
output = super().forward(
req=req,
prompt=prompt,
negative_prompt=negative_prompt,
image=image,
true_cfg_scale=true_cfg_scale,
height=height,
width=width,
num_inference_steps=num_inference_steps,
sigmas=sigmas,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
generator=generator,
latents=latents,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
output_type=output_type,
attention_kwargs=attention_kwargs,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
output = super().forward(req=req)

# Get actual num_inference_steps used
actual_num_steps = req.sampling_params.num_inference_steps or num_inference_steps
actual_num_steps = req.sampling_params.num_inference_steps or 50

# Create dummy trajectory data
dummy_trajectory_latents = torch.randn(actual_num_steps, 1, 16, 64, 64, dtype=torch.float32)
Expand Down
96 changes: 59 additions & 37 deletions vllm_omni/diffusion/models/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,54 +583,76 @@ def check_cfg_parallel_validity(self, true_cfg_scale: float, has_neg_prompt: boo
return False
return True

def forward(
self,
req: OmniDiffusionRequest,
prompt: str | list[str] | None = None,
prompt_2: str | list[str] | None = None,
negative_prompt: str | list[str] | None = None,
negative_prompt_2: str | list[str] | None = None,
true_cfg_scale: float = 1.0,
height: int | None = None,
width: int | None = None,
num_inference_steps: int = 28,
sigmas: list[float] | None = None,
guidance_scale: float = 3.5,
num_images_per_prompt: int = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.FloatTensor | None = None,
prompt_embeds: torch.FloatTensor | None = None,
pooled_prompt_embeds: torch.FloatTensor | None = None,
negative_prompt_embeds: torch.FloatTensor | None = None,
negative_pooled_prompt_embeds: torch.FloatTensor | None = None,
output_type: str | None = "pil",
return_dict: bool = True,
joint_attention_kwargs: dict[str, Any] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
max_sequence_length: int = 512,
):
def forward(self, req: OmniDiffusionRequest):
"""Forward pass for flux."""
# TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "")
# TODO: May be some data formatting operations on the API side. Hack for now.
prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt
prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts]

# For negative prompt, make it None if ALL are None---making it falsy and skipping CFG
# If only some of them are not None, only set those to empty strings---because we cannot skip CFG anyway.
if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts):
negative_prompt = None
elif req.prompts:
negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts]

req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts]
if any(p is not None for p in req_prompt_embeds):
try:
prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError
except TypeError:
raise ValueError(
"If you provide `prompt_embeds` for at least one prompt, you have to provide `prompt_embeds` for"
" all prompts so the pipeline can stack them together."
)
else:
prompt_embeds = None

req_negative_prompt_embeds = [
p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts
]
if any(p is not None for p in req_negative_prompt_embeds):
try:
negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError
except TypeError:
raise ValueError(
"If you provide `negative_prompt_embeds` for at least one prompt, "
"you have to provide `negative_prompt_embeds` for all prompts "
"so the pipeline can stack them together."
)
else:
negative_prompt_embeds = None

height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor
width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor
num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps
sigmas = req.sampling_params.sigmas or sigmas
guidance_scale = (
req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale
)
generator = req.sampling_params.generator or generator
true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale
num_inference_steps = req.sampling_params.num_inference_steps or 28
sigmas = req.sampling_params.sigmas
if req.sampling_params.guidance_scale_provided:
guidance_scale = req.sampling_params.guidance_scale
else:
guidance_scale = 3.5
generator = req.sampling_params.generator
true_cfg_scale = req.sampling_params.true_cfg_scale or 1.0
max_sequence_length = req.sampling_params.max_sequence_length or 512
num_images_per_prompt = (
req.sampling_params.num_outputs_per_prompt
if req.sampling_params.num_outputs_per_prompt > 0
else num_images_per_prompt
req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1
)
latents = req.sampling_params.latents

prompt_2: str | list[str] | None = req.sampling_params.extra_args.get("prompt_2", None)
negative_prompt_2: str | list[str] | None = req.sampling_params.extra_args.get("negative_prompt_2", None)
pooled_prompt_embeds: torch.FloatTensor | None = req.sampling_params.extra_args.get(
"pooled_prompt_embeds", None
)
negative_pooled_prompt_embeds: torch.FloatTensor | None = req.sampling_params.extra_args.get(
"negative_pooled_prompt_embeds", None
)
output_type: str = req.sampling_params.extra_args.get("output_type", "pil")
joint_attention_kwargs: dict[str, Any] | None = req.sampling_params.extra_args.get(
"joint_attention_kwargs", None
)
callback_on_step_end_tensor_inputs: list[str] = req.sampling_params.extra_args.get(
"callback_on_step_end_tensor_inputs", ["latents"]
)

# 1. Check inputs. Raise error if not correct
Expand Down
Loading
Loading