diff --git a/tests/diffusion/inputs/test_data.py b/tests/diffusion/inputs/test_data.py
new file mode 100644
index 0000000000..c92d8adb02
--- /dev/null
+++ b/tests/diffusion/inputs/test_data.py
@@ -0,0 +1,91 @@
+"""
+Tests for sampling parameter override behaviors.
+"""
+
+from copy import deepcopy
+
+from vllm_omni.inputs.data import DiffusionParamOverrides, OmniDiffusionSamplingParams
+
+
+def test_merge_nothing():
+    """Ensure that merging nothing doesn't break anything."""
+    user_params = OmniDiffusionSamplingParams()
+    overrides = DiffusionParamOverrides()
+    orig_params = deepcopy(user_params)
+    user_params.merge_with_def_params(overrides)
+    assert user_params.__dict__ == orig_params.__dict__
+    assert user_params._init_kwargs == set()
+
+
+def test_merge_unset():
+    """Ensure that we can override fields that are unset."""
+    default_steps = 777
+    user_params = OmniDiffusionSamplingParams()
+    overrides = DiffusionParamOverrides(num_inference_steps=default_steps)
+    user_params.merge_with_def_params(overrides)
+    assert user_params.num_inference_steps == default_steps
+    assert user_params._init_kwargs == set()
+
+
+def test_merge_priority():
+    """Ensure that explicitly passed values won't be overridden by pipelines."""
+    user_steps = 888
+    model_steps = 777
+    user_params = OmniDiffusionSamplingParams(
+        num_inference_steps=user_steps,
+    )
+    overrides = DiffusionParamOverrides(num_inference_steps=model_steps)
+    user_params.merge_with_def_params(overrides)
+    assert user_params.num_inference_steps == user_steps
+    assert user_params._init_kwargs == {"num_inference_steps"}
+
+
+def test_merge_multiple():
+    """Ensure that we can merge over truthy or falsy default values."""
+    model_steps = 888
+    model_resolution = 320
+    user_params = OmniDiffusionSamplingParams()
+    overrides = DiffusionParamOverrides(
+        num_inference_steps=model_steps,  # Falsy (None) by default
+        resolution=model_resolution,  # 640 by default
+    )
+    user_params.merge_with_def_params(overrides)
+    assert user_params.num_inference_steps == model_steps
+    assert user_params.resolution == model_resolution
+    assert user_params._init_kwargs == set()
+
+
+def test_hierarchical_merge_complex():
+    """Tests merge priority with multiple values."""
+    user_steps = 100
+    user_height = 100
+    user_width = 100
+    model_steps = 888  # clobbered by user steps
+    model_resolution = 320
+
+    user_params = OmniDiffusionSamplingParams(
+        num_inference_steps=user_steps,
+        height=user_height,
+        width=user_width,
+    )
+    overrides = DiffusionParamOverrides(
+        num_inference_steps=model_steps,  # lower priority than user param
+        resolution=model_resolution,
+    )
+    user_params.merge_with_def_params(overrides)
+    assert user_params.num_inference_steps == user_steps
+    assert user_params.height == user_height
+    assert user_params.width == user_width
+    assert user_params.resolution == model_resolution
+    assert user_params._init_kwargs == {"num_inference_steps", "height", "width"}
+
+
+def test_can_pass_falsy_override():
+    """Ensure an explicitly passed falsy value is kept over a pipeline override."""
+    user_params = OmniDiffusionSamplingParams(num_inference_steps=None)
+    overrides = DiffusionParamOverrides(
+        num_inference_steps=100,
+    )
+    user_params.merge_with_def_params(overrides)
+    assert user_params.num_inference_steps is None
+    assert user_params._init_kwargs == {"num_inference_steps"}
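The tests above pin down the merge contract: pipeline defaults fill only the fields a user never set, and `_init_kwargs` records which fields were passed explicitly, so even an explicit falsy value such as `None` survives the merge. A minimal sketch of that behavior, using stand-in classes rather than the real `OmniDiffusionSamplingParams`/`DiffusionParamOverrides` implementations:

```python
class OverridesSketch:
    """Stand-in for DiffusionParamOverrides: records explicitly set defaults."""

    def __init__(self, **kwargs):
        self.validated_overrides = dict(kwargs)


class SamplingParamsSketch:
    """Stand-in for OmniDiffusionSamplingParams with user-vs-default tracking."""

    def __init__(self, **kwargs):
        self.num_inference_steps = None  # library default until merged
        self._init_kwargs = set(kwargs)  # fields the user passed explicitly
        for name, value in kwargs.items():
            setattr(self, name, value)

    def merge_with_def_params(self, overrides):
        for name, value in overrides.validated_overrides.items():
            # A pipeline default never clobbers an explicit user value,
            # even a falsy one such as None or 0.
            if name not in self._init_kwargs:
                setattr(self, name, value)


params = SamplingParamsSketch(num_inference_steps=None)
params.merge_with_def_params(OverridesSketch(num_inference_steps=100))
assert params.num_inference_steps is None  # explicit None survives the merge
```

Tracking explicit kwargs in a set, rather than comparing values against defaults, is what makes the falsy case in `test_can_pass_falsy_override` decidable at all.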
diff --git a/tests/diffusion/models/test_base.py b/tests/diffusion/models/test_base.py
new file mode 100644
index 0000000000..f99a9aac5e
--- /dev/null
+++ b/tests/diffusion/models/test_base.py
@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Fast interface checks for all Diffusion pipelines."""
+
+from typing import cast
+
+import pytest
+
+from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline
+from vllm_omni.diffusion.registry import DiffusionModelRegistry
+from vllm_omni.inputs.data import DiffusionParamOverrides, OmniDiffusionSamplingParams
+
+# Pipelines to omit from common tests; this should be done sparingly,
+# since the tests are generic: only skip pipelines that cannot run them.
+SKIP_PIPELINES = ["DreamIDOmniPipeline"]
+
+# Instance variables that need to be mocked for sampling_param_defaults
+INSTANCE_VAR_MOCKS = {
+    "LTX2Pipeline": {"tokenizer_max_length": 512},
+    "LTX2ImageToVideoPipeline": {"tokenizer_max_length": 512},
+}
+
+TEST_PIPELINES = [pipe for pipe in DiffusionModelRegistry.models.keys() if pipe not in SKIP_PIPELINES]
+
+
+@pytest.mark.parametrize("pipeline_type", TEST_PIPELINES)
+def test_pipelines_are_vllm_diffusion_pipeline(pipeline_type):
+    """Ensure all pipelines are instances of VllmDiffusionPipeline."""
+    pipe_class = DiffusionModelRegistry._try_load_model_cls(pipeline_type)
+    assert pipe_class is not None
+    assert issubclass(pipe_class, VllmDiffusionPipeline)
+
+
+@pytest.mark.parametrize("pipeline_type", TEST_PIPELINES)
+def test_pipeline_sampling_params_are_valid(pipeline_type):
+    """Ensure all pipelines define sampling_param_defaults with valid param kwargs."""
+    pipe_class = DiffusionModelRegistry._try_load_model_cls(pipeline_type)
+    assert pipe_class is not None
+
+    # Create an uninitialized instance; this is easier than going through init/model load
+    # since the vast majority of models do not use instance vars in their default params
+    pipe_instance = object.__new__(pipe_class)
+
+    # Patch instance variables for any pipelines that do need it
+    if pipeline_type in INSTANCE_VAR_MOCKS:
+        for attr_name, attr_value in INSTANCE_VAR_MOCKS[pipeline_type].items():
+            setattr(pipe_instance, attr_name, attr_value)
+
+    # Verify sampling_param_defaults exists and has at least one key, since at a
+    # minimum every class will inherit num_inference_steps from the base class
+    defaults = pipe_instance.sampling_param_defaults
+    assert isinstance(defaults, DiffusionParamOverrides)
+    assert hasattr(defaults, "validated_overrides")
+    assert len(defaults.validated_overrides) > 0
+
+    # Ensure we can create a diffusion sampling params object (i.e., kwargs are valid)
+    params = OmniDiffusionSamplingParams(**defaults.validated_overrides)
+    for attr_name, val in defaults.validated_overrides.items():
+        assert hasattr(params, attr_name)
+        assert getattr(params, attr_name) == val
+
+
+@pytest.mark.parametrize("pipeline_type", TEST_PIPELINES)
+def test_merge_sampling_params(pipeline_type):
+    """Test sampling param / override merging."""
+    USER_STEPS = 999  # overrides all pipeline defaults
+    pipe_class = DiffusionModelRegistry._try_load_model_cls(pipeline_type)
+    assert pipe_class is not None
+    params = OmniDiffusionSamplingParams(num_inference_steps=USER_STEPS)
+
+    # Create an uninitialized instance; this is easier than going through init/model load
+    # since the vast majority of models do not use instance vars in their default params
+    pipe_instance = cast(VllmDiffusionPipeline, object.__new__(pipe_class))
+
+    # Patch instance variables for any pipelines that do need it
+    if pipeline_type in INSTANCE_VAR_MOCKS:
+        for attr_name, attr_value in INSTANCE_VAR_MOCKS[pipeline_type].items():
+            setattr(pipe_instance, attr_name, attr_value)
+
+    defaults = pipe_instance.sampling_param_defaults
+    
params.merge_with_def_params(defaults) + + # Ensure the user override is prioritized for all models + assert params.num_inference_steps == USER_STEPS + # For every other property, it should match the pipeline defaults since user didn't pass it + for attr_name, val in defaults.validated_overrides.items(): + if attr_name != "num_inference_steps": + assert getattr(params, attr_name) == val diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index 72e53e7f48..357ec8c4af 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -26,8 +26,10 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import DiffusionParamOverrides from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific from .autoencoder import AutoEncoder, AutoEncoderParams @@ -148,12 +150,18 @@ def forward(self, packed_pixel_values, packed_flattened_position_ids, cu_seqlens return outputs.last_hidden_state.squeeze(0) -class BagelPipeline(nn.Module, DiffusionPipelineProfilerMixin): +class BagelPipeline(VllmDiffusionPipeline, DiffusionPipelineProfilerMixin): """Bagel generation pipeline (MoT) packaged for vllm-omni diffusion engine. This pipeline is self-contained and uses the ported Bagel core files. 
""" + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=50, + ) + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): super().__init__() self.od_config = od_config @@ -334,7 +342,7 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: cfg_renorm_min = extra_args.get("cfg_renorm_min", 0.0) gen_params = BagelGenParams( - num_timesteps=int(req.sampling_params.num_inference_steps or 50), + num_timesteps=int(req.sampling_params.num_inference_steps), timestep_shift=3.0, cfg_text_scale=cfg_text_scale, cfg_img_scale=cfg_img_scale, diff --git a/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py b/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py index 974cc582f1..7955e486b7 100644 --- a/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py +++ b/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py @@ -9,15 +9,15 @@ import torch.distributed from diffusers import FlowMatchEulerDiscreteScheduler from PIL import Image, ImageOps -from torch import nn from torchvision.transforms import Compose, Normalize from tqdm import tqdm from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device -from vllm_omni.diffusion.models.interface import SupportAudioInput, SupportImageInput +from vllm_omni.diffusion.models.interface import SupportAudioInput, SupportImageInput, VllmDiffusionPipeline from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import DiffusionParamOverrides try: from dreamid_omni.utils.divisible_crop import DivisibleCrop @@ -74,9 +74,15 @@ } -class DreamIDOmniPipeline(nn.Module, CFGParallelMixin, SupportImageInput, SupportAudioInput): +class DreamIDOmniPipeline(VllmDiffusionPipeline, CFGParallelMixin, SupportImageInput, SupportAudioInput): """DreamID-Omni pipeline for vLLM-Omni.""" + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=50, + ) + def __init__( self, *, diff --git a/vllm_omni/diffusion/models/flux/pipeline_flux.py b/vllm_omni/diffusion/models/flux/pipeline_flux.py index 6f43e8dbb5..377346b550 100644 --- a/vllm_omni/diffusion/models/flux/pipeline_flux.py +++ b/vllm_omni/diffusion/models/flux/pipeline_flux.py @@ -16,7 +16,6 @@ FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.torch_utils import randn_tensor -from torch import nn from transformers import AutoConfig, CLIPTextModel, CLIPTokenizer, T5TokenizerFast from vllm.model_executor.models.utils import AutoWeightsLoader @@ -27,9 +26,11 @@ from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.flux import FluxTransformer2DModel from vllm_omni.diffusion.models.flux.flux_pipeline_mixin import FluxPipelineMixin +from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline from vllm_omni.diffusion.models.t5_encoder import T5EncoderModel from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import DiffusionParamOverrides from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific logger = logging.getLogger(__name__) @@ -64,7 +65,15 @@ def post_process_func(images: torch.Tensor): return post_process_func -class 
FluxPipeline(nn.Module, FluxPipelineMixin, CFGParallelMixin, DiffusionPipelineProfilerMixin): +class FluxPipeline(VllmDiffusionPipeline, FluxPipelineMixin, CFGParallelMixin, DiffusionPipelineProfilerMixin): + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=28, + true_cfg_scale=1.0, + max_sequence_length=512, + ) + def __init__( self, *, @@ -494,14 +503,11 @@ def forward( prompt_2: str | list[str] | None = None, negative_prompt: str | list[str] | None = None, negative_prompt_2: str | list[str] | None = None, - true_cfg_scale: float = 1.0, height: int | None = None, width: int | None = None, - num_inference_steps: int = 28, sigmas: list[float] | None = None, guidance_scale: float = 3.5, num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.FloatTensor | None = None, prompt_embeds: torch.FloatTensor | None = None, pooled_prompt_embeds: torch.FloatTensor | None = None, @@ -511,7 +517,6 @@ def forward( return_dict: bool = True, joint_attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, ): """Forward pass for flux.""" # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") @@ -524,13 +529,14 @@ def forward( height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_inference_steps = req.sampling_params.num_inference_steps sigmas = req.sampling_params.sigmas or sigmas guidance_scale = ( req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale ) - generator = req.sampling_params.generator or generator - true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + max_sequence_length = req.sampling_params.max_sequence_length + generator = req.sampling_params.generator + true_cfg_scale = req.sampling_params.true_cfg_scale num_images_per_prompt = ( req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 diff --git a/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py b/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py index c3bea7dd1c..ede599db6d 100644 --- a/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py +++ b/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py @@ -14,7 +14,6 @@ import numpy as np import PIL.Image import torch -import torch.nn as nn from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import TextualInversionLoaderMixin from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL @@ -30,11 +29,12 @@ FluxKontextTransformer2DModel, ) from vllm_omni.diffusion.models.flux.flux_pipeline_mixin import FluxPipelineMixin -from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.interface import SupportImageInput, VllmDiffusionPipeline from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.inputs.data import DiffusionParamOverrides from vllm_omni.logger import init_logger 
logger = init_logger(__name__) @@ -70,10 +70,18 @@ def post_process_func(images: torch.Tensor) -> list[PIL.Image.Image]: class FluxKontextPipeline( - nn.Module, FluxPipelineMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin + VllmDiffusionPipeline, FluxPipelineMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin ): """FLUX.1-Kontext pipeline for image editing with text guidance.""" + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=28, + true_cfg_scale=1.0, + max_sequence_length=512, + ) + support_image_input = True def __init__( @@ -467,11 +475,8 @@ def forward( negative_prompt_2: str | list[str] | None = None, height: int | None = None, width: int | None = None, - num_inference_steps: int = 28, guidance_scale: float = 3.5, - true_cfg_scale: float = 1.0, num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, prompt_embeds: torch.Tensor | None = None, pooled_prompt_embeds: torch.Tensor | None = None, @@ -482,7 +487,6 @@ def forward( attention_kwargs: dict[str, Any] | None = None, callback_on_step_end: Callable[[int, int, dict], None] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, sigmas: list[float] | None = None, ) -> DiffusionOutput: # Handle multiple prompts - only take the first one, similar to Flux2KleinPipeline @@ -514,9 +518,11 @@ def forward( image = PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image) height = req.sampling_params.height or height width = req.sampling_params.width or width - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_inference_steps = req.sampling_params.num_inference_steps guidance_scale = req.sampling_params.guidance_scale or guidance_scale - generator = req.sampling_params.generator or generator + generator = req.sampling_params.generator + true_cfg_scale = req.sampling_params.true_cfg_scale + max_sequence_length = req.sampling_params.max_sequence_length latents = ( req.sampling_params.extra_args.get("latents") if req.sampling_params.extra_args.get("latents") is not None diff --git a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py index 404f05b606..caa5f312c9 100644 --- a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py +++ b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py @@ -20,7 +20,6 @@ SYSTEM_MESSAGE_UPSAMPLING_T2I, ) from diffusers.utils.torch_utils import randn_tensor -from torch import nn from transformers import AutoProcessor, Mistral3ForConditionalGeneration, PixtralProcessor from vllm.model_executor.models.utils import AutoWeightsLoader @@ -30,11 +29,12 @@ from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.flux2 import Flux2Transformer2DModel -from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.interface import SupportImageInput, VllmDiffusionPipeline from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.inputs.data 
import DiffusionParamOverrides from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific logger = logging.getLogger(__name__) @@ -335,13 +335,22 @@ def retrieve_latents(encoder_output: torch.Tensor, generator: torch.Generator = raise AttributeError("Could not access latents of provided encoder_output") -class Flux2Pipeline(nn.Module, CFGParallelMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin): +class Flux2Pipeline( + VllmDiffusionPipeline, CFGParallelMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin +): """Flux2 pipeline for text-to-image generation.""" _callback_tensor_inputs = ["latents", "prompt_embeds"] support_image_input = True + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=50, + max_sequence_length=512, + ) + def __init__( self, *, @@ -878,11 +887,9 @@ def forward( prompt: str | list[str] | None = None, height: int | None = None, width: int | None = None, - num_inference_steps: int = 50, sigmas: list[float] | None = None, guidance_scale: float | None = 4.0, num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, @@ -891,7 +898,6 @@ def forward( attention_kwargs: dict[str, Any] | None = None, callback_on_step_end: Callable[[int, int, dict], None] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, text_encoder_out_layers: tuple[int, ...] = (10, 20, 30), caption_upsample_temperature: float = None, ) -> DiffusionOutput: @@ -916,18 +922,18 @@ def forward( height = req.sampling_params.height or height width = req.sampling_params.width or width - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_inference_steps = req.sampling_params.num_inference_steps sigmas = req.sampling_params.sigmas or sigmas guidance_scale = ( req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale ) - generator = req.sampling_params.generator or generator + generator = req.sampling_params.generator num_images_per_prompt = ( req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else num_images_per_prompt ) - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + max_sequence_length = req.sampling_params.max_sequence_length text_encoder_out_layers = req.sampling_params.extra_args.get("text_encoder_out_layers", text_encoder_out_layers) req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index 437dd58d0c..bd1681e8dc 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -24,7 +24,6 @@ import numpy as np import PIL.Image import torch -import torch.nn as nn from diffusers.image_processor import VaeImageProcessor from diffusers.models.autoencoders.autoencoder_kl_flux2 import AutoencoderKLFlux2 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps @@ -42,10 +41,11 @@ from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import ( Flux2Transformer2DModel, ) 
-from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.interface import SupportImageInput, VllmDiffusionPipeline from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.inputs.data import DiffusionParamOverrides from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific logger = init_logger(__name__) @@ -180,11 +180,18 @@ def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: return float(mu) -class Flux2KleinPipeline(nn.Module, CFGParallelMixin, SupportImageInput, DiffusionPipelineProfilerMixin): +class Flux2KleinPipeline(VllmDiffusionPipeline, CFGParallelMixin, SupportImageInput, DiffusionPipelineProfilerMixin): """Flux2 klein pipeline for text-to-image generation.""" support_image_input = True + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=50, + max_sequence_length=512, + ) + def __init__( self, *, @@ -656,11 +663,9 @@ def forward( prompt: str | list[str] | None = None, height: int | None = None, width: int | None = None, - num_inference_steps: int = 50, sigmas: list[float] | None = None, guidance_scale: float | None = 4.0, num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, @@ -669,7 +674,6 @@ def forward( attention_kwargs: dict[str, Any] | None = None, callback_on_step_end: Callable[[int, int, dict], None] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, text_encoder_out_layers: tuple[int, ...] 
= (9, 18, 27), ) -> DiffusionOutput: r""" @@ -768,18 +772,18 @@ def forward( height = req.sampling_params.height or height width = req.sampling_params.width or width - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_inference_steps = req.sampling_params.num_inference_steps sigmas = req.sampling_params.sigmas or sigmas guidance_scale = ( req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale ) - generator = req.sampling_params.generator or generator + generator = req.sampling_params.generator num_images_per_prompt = ( req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else num_images_per_prompt ) - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + max_sequence_length = req.sampling_params.max_sequence_length text_encoder_out_layers = req.sampling_params.extra_args.get("text_encoder_out_layers", text_encoder_out_layers) req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] diff --git a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py index 375f7e7b80..858b2f16fd 100644 --- a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py +++ b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py @@ -28,7 +28,6 @@ FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.torch_utils import randn_tensor -from torch import nn from transformers import ( ByT5Tokenizer, T5EncoderModel, @@ -46,9 +45,10 @@ GlmImageKVCache, GlmImageTransformer2DModel, ) +from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.inputs.data import DiffusionParamOverrides, OmniTextPrompt from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -238,7 +238,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class GlmImagePipeline(nn.Module, DiffusionPipelineProfilerMixin): +class GlmImagePipeline(VllmDiffusionPipeline, DiffusionPipelineProfilerMixin): """ GLM-Image Pipeline for text-to-image and image-to-image generation. @@ -255,6 +255,12 @@ class GlmImagePipeline(nn.Module, DiffusionPipelineProfilerMixin): 4. 
VAE decodes final latents to image """ + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=50, + ) + def __init__( self, *, @@ -722,7 +728,7 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: # Use image dimensions as default if available height = req.sampling_params.height or img_height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or img_width or self.default_sample_size * self.vae_scale_factor - num_inference_steps = req.sampling_params.num_inference_steps or 50 + num_inference_steps = req.sampling_params.num_inference_steps guidance_scale = req.sampling_params.guidance_scale or 1.5 self.check_inputs(prompt=prompt, height=height, width=width, prompt_embeds=prompt_embeds) diff --git a/vllm_omni/diffusion/models/helios/pipeline_helios.py b/vllm_omni/diffusion/models/helios/pipeline_helios.py index df709a515e..8899381d35 100644 --- a/vllm_omni/diffusion/models/helios/pipeline_helios.py +++ b/vllm_omni/diffusion/models/helios/pipeline_helios.py @@ -16,7 +16,6 @@ import torch.nn.functional as F from diffusers import AutoencoderKLWan from diffusers.utils.torch_utils import randn_tensor -from torch import nn from transformers import AutoConfig, AutoTokenizer, UMT5EncoderModel from vllm.model_executor.models.utils import AutoWeightsLoader @@ -26,9 +25,11 @@ from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.helios.helios_transformer import HeliosTransformer3DModel from vllm_omni.diffusion.models.helios.scheduling_helios import HeliosScheduler +from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import DiffusionParamOverrides from vllm_omni.platforms import current_omni_platform if TYPE_CHECKING: @@ -149,13 +150,20 @@ def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: return pre_process_func -class HeliosPipeline(nn.Module, CFGParallelMixin, ProgressBarMixin, DiffusionPipelineProfilerMixin): +class HeliosPipeline(VllmDiffusionPipeline, CFGParallelMixin, ProgressBarMixin, DiffusionPipelineProfilerMixin): """Helios text-to-video / image-to-video / video-to-video pipeline for vllm-omni. Supports T2V, I2V (with image input), and V2V (with video input). Implements chunked video generation with multi-term memory history context. 
""" + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=50, + max_sequence_length=226, + ) + def __init__( self, *, @@ -265,11 +273,9 @@ def forward( negative_prompt: str | None = None, height: int = 384, width: int = 640, - num_inference_steps: int = 50, guidance_scale: float = 5.0, frame_num: int = 132, output_type: str | None = "np", - generator: torch.Generator | list[torch.Generator] | None = None, prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, attention_kwargs: dict | None = None, @@ -342,7 +348,7 @@ def forward( height = req.sampling_params.height or height width = req.sampling_params.width or width num_frames = req.sampling_params.num_frames if req.sampling_params.num_frames else frame_num - num_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_steps = req.sampling_params.num_inference_steps if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale @@ -356,8 +362,7 @@ def forward( device = self.device dtype = self.transformer.dtype - if generator is None: - generator = req.sampling_params.generator + generator = req.sampling_params.generator if generator is None and req.sampling_params.seed is not None: generator = torch.Generator(device=device).manual_seed(req.sampling_params.seed) @@ -368,7 +373,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or 226, + max_sequence_length=req.sampling_params.max_sequence_length, device=device, dtype=dtype, ) diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py b/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py index 7e9e2d2787..0c63f436f1 100644 --- a/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py +++ b/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py @@ -21,8 +21,10 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import DiffusionParamOverrides from .autoencoder import AutoencoderKLConv3D from .hunyuan_image_3_tokenizer import TokenizerWrapper @@ -64,7 +66,9 @@ def to_device(data, device): return data -class HunyuanImage3Pipeline(HunyuanImage3PreTrainedModel, GenerationMixin, DiffusionPipelineProfilerMixin): +class HunyuanImage3Pipeline( + VllmDiffusionPipeline, HunyuanImage3PreTrainedModel, GenerationMixin, DiffusionPipelineProfilerMixin +): _PROFILER_TARGETS = [ "model.forward", "model.layers[0].forward", @@ -74,6 +78,12 @@ class HunyuanImage3Pipeline(HunyuanImage3PreTrainedModel, GenerationMixin, Diffu "vae.decode", ] + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=50, + ) + def __init__(self, od_config: OmniDiffusionConfig) -> None: self.hf_config = get_config(od_config.model, trust_remote_code=True) super().__init__(self.hf_config) @@ -990,9 +1000,8 @@ def forward( image_size="auto", height: 
int = 1024, width: int = 1024, - num_inference_steps: int = 50, guidance_scale: float = 5.0, - generator: torch.Generator | list[torch.Generator] | None = None, + system_prompt: str | None = None, **kwargs, ) -> DiffusionOutput: extra_args = getattr(getattr(req, "sampling_params", None), "extra_args", {}) or {} @@ -1002,10 +1011,10 @@ def forward( system_prompt = get_system_prompt(use_system_prompt, "image", system_prompt) system_prompt = system_prompt.strip() if system_prompt is not None else "" prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt - generator = req.sampling_params.generator or generator + generator = req.sampling_params.generator height = req.sampling_params.height or height width = req.sampling_params.width or width - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_inference_steps = req.sampling_params.num_inference_steps if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale if guidance_scale <= 1.0: diff --git a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py index 6445bfee21..c4d45dbd2a 100644 --- a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py +++ b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py @@ -15,7 +15,6 @@ from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler from diffusers.utils.torch_utils import randn_tensor from diffusers.video_processor import VideoProcessor -from torch import nn from transformers import AutoConfig, ByT5Tokenizer, Qwen2_5_VLTextModel, Qwen2Tokenizer from vllm.model_executor.models.utils import AutoWeightsLoader @@ -25,10 +24,12 @@ from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.hunyuan_video.hunyuan_video_15_transformer import HunyuanVideo15Transformer3DModel from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin +from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline from vllm_omni.diffusion.models.t5_encoder import T5EncoderModel from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.inputs.data import DiffusionParamOverrides from vllm_omni.platforms import current_omni_platform logger = logging.getLogger(__name__) @@ -83,7 +84,13 @@ def post_process_func(video: torch.Tensor, output_type: str = "pil"): return post_process_func -class HunyuanVideo15Pipeline(nn.Module, CFGParallelMixin, ProgressBarMixin, DiffusionPipelineProfilerMixin): +class HunyuanVideo15Pipeline(VllmDiffusionPipeline, CFGParallelMixin, ProgressBarMixin, DiffusionPipelineProfilerMixin): + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=50, + ) + def __init__( self, *, @@ -367,7 +374,6 @@ def predict_noise(self, **kwargs: Any) -> torch.Tensor: def forward( self, req: OmniDiffusionRequest, - num_inference_steps: int = 50, guidance_scale: float = 6.0, height: int = 480, width: int = 832, @@ -387,7 +393,7 @@ def forward( height = req.sampling_params.height or height width = req.sampling_params.width or width num_frames_val = req.sampling_params.num_frames if req.sampling_params.num_frames else num_frames - 
num_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_steps = req.sampling_params.num_inference_steps if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale diff --git a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py index c1acd1a895..07833506ae 100644 --- a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py +++ b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py @@ -15,7 +15,6 @@ from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler from diffusers.utils.torch_utils import randn_tensor from diffusers.video_processor import VideoProcessor -from torch import nn from transformers import ( AutoConfig, ByT5Tokenizer, @@ -37,12 +36,13 @@ get_hunyuan_video_15_post_process_func, retrieve_latents, ) -from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.interface import SupportImageInput, VllmDiffusionPipeline from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.models.t5_encoder import T5EncoderModel from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.inputs.data import DiffusionParamOverrides from vllm_omni.platforms import current_omni_platform logger = logging.getLogger(__name__) @@ -101,11 +101,17 @@ def pre_process_func(req: OmniDiffusionRequest) -> OmniDiffusionRequest: class HunyuanVideo15I2VPipeline( - nn.Module, CFGParallelMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin + VllmDiffusionPipeline, CFGParallelMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin ): support_image_input = True color_format = "RGB" + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=50, + ) + def __init__( self, *, @@ -431,7 +437,6 @@ def predict_noise(self, **kwargs: Any) -> torch.Tensor: def forward( self, req: OmniDiffusionRequest, - num_inference_steps: int = 50, guidance_scale: float = 6.0, height: int = 480, width: int = 832, @@ -470,7 +475,7 @@ def forward( height = req.sampling_params.height or height width = req.sampling_params.width or width num_frames_val = req.sampling_params.num_frames if req.sampling_params.num_frames else num_frames - num_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_steps = req.sampling_params.num_inference_steps if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale diff --git a/vllm_omni/diffusion/models/interface.py b/vllm_omni/diffusion/models/interface.py index ef906472bd..f2b37d476b 100644 --- a/vllm_omni/diffusion/models/interface.py +++ b/vllm_omni/diffusion/models/interface.py @@ -10,6 +10,10 @@ runtime_checkable, ) +from torch import nn + +from vllm_omni.inputs.data import DiffusionParamOverrides + if TYPE_CHECKING: import torch @@ -17,6 +21,25 @@ from vllm_omni.diffusion.worker.utils import DiffusionRequestState +class VllmDiffusionPipeline(nn.Module): + """Base class for all vLLM Omni diffusion pipelines. + + All registered diffusion pipelines should inherit from this class. 
+ Currently, this is only used for ensuring the correct sampling params + can be fetched for cache refresh, but additional common capabilities are + actively being added here. + + See the following RFC: https://github.com/vllm-project/vllm-omni/issues/2189 + """ + + @property + def sampling_param_defaults(self) -> DiffusionParamOverrides: + """Pipeline-specific default sampling parameters.""" + return DiffusionParamOverrides( + num_inference_steps=50, + ) + + @runtime_checkable class SupportImageInput(Protocol): support_image_input: ClassVar[bool] = True diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py index 76d3efa2f8..0daa0a88ea 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py @@ -18,7 +18,6 @@ from diffusers.pipelines.longcat_image.system_messages import SYSTEM_PROMPT_EN, SYSTEM_PROMPT_ZH from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, SchedulerMixin from diffusers.utils.torch_utils import randn_tensor -from torch import nn from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor from vllm.logger import init_logger from vllm.model_executor.models.utils import AutoWeightsLoader @@ -27,9 +26,11 @@ from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline from vllm_omni.diffusion.models.longcat_image.longcat_image_transformer import LongCatImageTransformer2DModel from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import DiffusionParamOverrides from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -200,7 +201,13 @@ def get_prompt_language(prompt): return "en" -class LongCatImagePipeline(nn.Module, CFGParallelMixin, DiffusionPipelineProfilerMixin): +class LongCatImagePipeline(VllmDiffusionPipeline, CFGParallelMixin, DiffusionPipelineProfilerMixin): + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=50, + ) + def __init__( self, *, @@ -481,11 +488,9 @@ def forward( negative_prompt: str | list[str] | None = None, height: int | None = None, width: int | None = None, - num_inference_steps: int = 50, sigmas: list[float] | None = None, guidance_scale: float = 4.5, num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.FloatTensor | None = None, prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, @@ -506,9 +511,9 @@ def forward( height = req.sampling_params.height or height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or width or self.default_sample_size * self.vae_scale_factor - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_inference_steps = req.sampling_params.num_inference_steps sigmas = req.sampling_params.sigmas or sigmas - generator = req.sampling_params.generator or generator + generator = req.sampling_params.generator guidance_scale = ( req.sampling_params.guidance_scale if 
req.sampling_params.guidance_scale is not None else guidance_scale ) diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py index 7eccf68636..66d89e51ea 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py @@ -16,7 +16,6 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils.torch_utils import randn_tensor -from torch import nn from transformers import ( AutoTokenizer, Qwen2_5_VLForConditionalGeneration, @@ -29,14 +28,14 @@ from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.interface import SupportImageInput, VllmDiffusionPipeline from vllm_omni.diffusion.models.longcat_image.longcat_image_transformer import ( LongCatImageTransformer2DModel, ) from vllm_omni.diffusion.models.longcat_image.pipeline_longcat_image import calculate_shift from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.inputs.data import DiffusionParamOverrides, OmniTextPrompt from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -221,7 +220,15 @@ def split_quotation(prompt, quote_pairs=None): return result -class LongCatImageEditPipeline(nn.Module, CFGParallelMixin, SupportImageInput, DiffusionPipelineProfilerMixin): +class LongCatImageEditPipeline( + VllmDiffusionPipeline, CFGParallelMixin, SupportImageInput, DiffusionPipelineProfilerMixin +): + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=50, + ) + def __init__( self, *, @@ -529,11 +536,9 @@ def forward( image: PIL.Image.Image | torch.Tensor | None = None, prompt: str | list[str] | None = None, negative_prompt: str | list[str] | None = None, - num_inference_steps: int = 50, sigmas: list[float] | None = None, guidance_scale: float = 3.5, num_images_per_prompt: int | None = 1, - generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.FloatTensor | None = None, prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, @@ -558,13 +563,13 @@ def forward( guidance_scale = ( req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale ) - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_inference_steps = req.sampling_params.num_inference_steps num_images_per_prompt = ( req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt is not None else num_images_per_prompt ) - generator = req.sampling_params.generator or generator + generator = req.sampling_params.generator height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor diff --git a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py 
b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py index c60b192f0a..ec1bf52664 100644 --- a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py +++ b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py @@ -20,7 +20,6 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg, retrieve_timesteps from diffusers.utils.torch_utils import randn_tensor from diffusers.video_processor import VideoProcessor -from torch import nn from transformers import AutoTokenizer, Gemma3ForConditionalGeneration from vllm.logger import init_logger from vllm.model_executor.models.utils import AutoWeightsLoader @@ -33,8 +32,10 @@ from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import DiffusionParamOverrides from vllm_omni.lora.request import LoRARequest from .ltx2_transformer import LTX2VideoTransformer3DModel @@ -145,7 +146,14 @@ def step(self, noise_pred, t, latents, return_dict=False, generator=None): return ((video_out, audio_out),) -class LTX2Pipeline(nn.Module, CFGParallelMixin, ProgressBarMixin): +class LTX2Pipeline(VllmDiffusionPipeline, CFGParallelMixin, ProgressBarMixin): + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=40, + max_sequence_length=self.tokenizer_max_length, + ) + def __init__( self, *, @@ -755,7 +763,6 @@ def forward( width: int | None = None, num_frames: int | None = None, frame_rate: float | None = None, - num_inference_steps: int | None = None, sigmas: list[float] | None = None, timesteps: list[int] | None = None, guidance_scale: float = 4.0, @@ -774,7 +781,6 @@ def forward( output_type: str = "np", return_dict: bool = True, attention_kwargs: dict[str, Any] | None = None, - max_sequence_length: int | None = None, ) -> DiffusionOutput: # Extract prompt/negative_prompt from request. # Input format: req.prompts is a list of str or dict with "prompt"/"negative_prompt" keys. 
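LTX2 is the pipeline that motivates `INSTANCE_VAR_MOCKS` in the new interface tests: unlike most pipelines, its `sampling_param_defaults` reads instance state (`self.tokenizer_max_length`), so the property cannot be evaluated on a bare class. A sketch of the mocking strategy the tests use, where `object.__new__` bypasses the heavy `__init__` and only the one attribute the property touches is patched (the class below is a stand-in, not the real pipeline):

```python
# Stand-in illustrating an instance-dependent defaults property.
class LTX2Sketch:
    @property
    def sampling_param_defaults(self):
        # Depends on state normally set during model load.
        return {
            "num_inference_steps": 40,
            "max_sequence_length": self.tokenizer_max_length,
        }


pipe = object.__new__(LTX2Sketch)  # skip __init__ (no model load)
pipe.tokenizer_max_length = 512    # patch only what the property reads
assert pipe.sampling_param_defaults["max_sequence_length"] == 512
```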
@@ -788,7 +794,7 @@ def forward( width = req.sampling_params.width or width or 768 num_frames = req.sampling_params.num_frames or num_frames or 121 frame_rate = req.sampling_params.resolved_frame_rate or frame_rate or 24.0 - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps or 40 + num_inference_steps = req.sampling_params.num_inference_steps if timesteps is None: num_inference_steps = max(int(num_inference_steps), 2) elif len(timesteps) < 2: @@ -798,9 +804,7 @@ def forward( if req.sampling_params.num_outputs_per_prompt > 0 else num_videos_per_prompt or 1 ) - max_sequence_length = ( - req.sampling_params.max_sequence_length or max_sequence_length or self.tokenizer_max_length - ) + max_sequence_length = req.sampling_params.max_sequence_length if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale @@ -1151,7 +1155,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return loader.load_weights(weights) -class LTX2TwoStagesPipeline(nn.Module): +class LTX2TwoStagesPipeline(VllmDiffusionPipeline): """LTX2TwoStagesPipeline is for two stages image to video generation""" def __init__( diff --git a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py index 65e7454b73..f5542b34e5 100644 --- a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py +++ b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py @@ -291,7 +291,6 @@ def forward( width: int | None = None, num_frames: int | None = None, frame_rate: float | None = None, - num_inference_steps: int | None = None, sigmas: list[float] | None = None, timesteps: list[int] | None = None, guidance_scale: float = 4.0, @@ -310,7 +309,6 @@ def forward( output_type: str = "np", return_dict: bool = True, attention_kwargs: dict[str, Any] | None = None, - max_sequence_length: int | None = None, ) -> DiffusionOutput: # Extract prompt/negative_prompt from request. # Input format: req.prompts is a list of str or dict with "prompt"/"negative_prompt" keys. 
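The change in the next hunk, repeated across every pipeline in this diff, replaces `req.sampling_params.X or fallback` with a plain attribute read. The reason is that `or` cannot distinguish an explicit falsy value from an unset one; with pipeline defaults now merged into the sampling params before dispatch, the in-forward fallback is both redundant and subtly wrong. The pitfall in isolation (values illustrative):

```python
# Old pattern: an explicit request for a falsy value is silently replaced.
requested_steps = 0
steps = requested_steps or 50
assert steps == 50  # user intent lost

# New pattern: defaults were already merged upstream, so the field is
# read directly and an explicit falsy value is preserved.
steps = requested_steps
assert steps == 0
```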
@@ -324,7 +322,7 @@ def forward( width = req.sampling_params.width or width or 768 num_frames = req.sampling_params.num_frames or num_frames or 121 frame_rate = req.sampling_params.resolved_frame_rate or frame_rate or 24.0 - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps or 40 + num_inference_steps = req.sampling_params.num_inference_steps if timesteps is None: num_inference_steps = max(int(num_inference_steps), 2) elif len(timesteps) < 2: @@ -334,9 +332,7 @@ def forward( if req.sampling_params.num_outputs_per_prompt > 0 else num_videos_per_prompt or 1 ) - max_sequence_length = ( - req.sampling_params.max_sequence_length or max_sequence_length or self.tokenizer_max_length - ) + max_sequence_length = req.sampling_params.max_sequence_length if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale diff --git a/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py b/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py index 4fa56ea931..b0e17ef29c 100644 --- a/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py +++ b/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py @@ -13,7 +13,6 @@ import torchvision.transforms as transforms from diffusers.image_processor import VaeImageProcessor from PIL import Image -from torch import nn from tqdm.auto import tqdm from transformers import AutoTokenizer, PreTrainedTokenizer from transformers.cache_utils import StaticCache @@ -28,6 +27,7 @@ ) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline from vllm_omni.diffusion.models.nextstep_1_1.modeling_flux_vae import AutoencoderKL from vllm_omni.diffusion.models.nextstep_1_1.modeling_nextstep import ( NextStepConfig, @@ -35,6 +35,7 @@ ) from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import DiffusionParamOverrides from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -131,7 +132,7 @@ def post_process_func(images: torch.Tensor): return post_process_func -class NextStep11Pipeline(nn.Module, DiffusionPipelineProfilerMixin): +class NextStep11Pipeline(VllmDiffusionPipeline, DiffusionPipelineProfilerMixin): """ NextStep-1.1 Pipeline for text-to-image generation. @@ -140,6 +141,12 @@ class NextStep11Pipeline(nn.Module, DiffusionPipelineProfilerMixin): to generate images autoregressively. 
""" + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=28, + ) + def __init__( self, *, @@ -558,11 +565,9 @@ def forward( prompt: str | list[str] | None = None, height: int | None = None, width: int | None = None, - num_inference_steps: int = 28, guidance_scale: float = 7.5, negative_prompt: str | list[str] | None = None, num_images_per_prompt: int = 1, - generator: torch.Generator | None = None, seed: int | None = None, **kwargs, ) -> DiffusionOutput: @@ -596,7 +601,7 @@ def forward( height = req.sampling_params.height or height or 512 width = req.sampling_params.width or width or 512 - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_inference_steps = req.sampling_params.num_inference_steps if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale num_images_per_prompt = ( @@ -618,6 +623,7 @@ def forward( positive_prompt = req.sampling_params.extra_args.get("positive_prompt", None) # Set seed for reproducibility (use generator if provided, else fall back to seed) + generator = req.sampling_params.generator if generator is None and seed is not None: set_seed(seed) elif generator is not None: diff --git a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py index e8e307b878..6745804404 100644 --- a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py +++ b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py @@ -13,7 +13,6 @@ import numpy as np import PIL.Image import torch -import torch.nn as nn import torch.nn.functional as F from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.image_processor import ( @@ -32,13 +31,14 @@ from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline from vllm_omni.diffusion.models.omnigen2.omnigen2_transformer import ( OmniGen2RotaryPosEmbed, OmniGen2Transformer2DModel, ) from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs -from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.inputs.data import DiffusionParamOverrides, OmniTextPrompt from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -620,7 +620,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class OmniGen2Pipeline(CFGParallelMixin, nn.Module): +class OmniGen2Pipeline(VllmDiffusionPipeline, CFGParallelMixin): """ Pipeline for text-to-image generation using OmniGen2. 
@@ -960,6 +960,12 @@ def encode_prompt( negative_prompt_attention_mask, ) + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=28, + ) + @property def num_timesteps(self): return self._num_timesteps @@ -994,13 +1000,11 @@ def forward( max_pixels: int = 1024 * 1024, max_input_image_side_length: int = 1024, align_res: bool = True, - num_inference_steps: int = 28, text_guidance_scale: float = 4.0, image_guidance_scale: float = 1.0, cfg_range: tuple[float, float] = (0.0, 1.0), attention_kwargs: dict[str, Any] | None = None, timesteps: list[int] = None, - generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.FloatTensor | None = None, verbose: bool = False, step_func=None, @@ -1026,8 +1030,8 @@ def forward( height = req.sampling_params.height or height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or width or self.default_sample_size * self.vae_scale_factor - generator = req.sampling_params.generator or generator - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + generator = req.sampling_params.generator + num_inference_steps = req.sampling_params.num_inference_steps if req.sampling_params.guidance_scale_provided: text_guidance_scale = req.sampling_params.guidance_scale self._text_guidance_scale = text_guidance_scale diff --git a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py index 8de617dd95..6ccbcc8d82 100644 --- a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py +++ b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py @@ -29,7 +29,6 @@ FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.torch_utils import randn_tensor -from torch import nn from transformers import Qwen2TokenizerFast, Qwen3Model from vllm.logger import init_logger from vllm.model_executor.models.utils import AutoWeightsLoader @@ -38,9 +37,11 @@ from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline from vllm_omni.diffusion.models.ovis_image.ovis_image_transformer import OvisImageTransformer2DModel from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import DiffusionParamOverrides from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific logger = init_logger(__name__) @@ -141,7 +142,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class OvisImagePipeline(nn.Module, CFGParallelMixin, DiffusionPipelineProfilerMixin): +class OvisImagePipeline(VllmDiffusionPipeline, CFGParallelMixin, DiffusionPipelineProfilerMixin): def __init__( self, *, @@ -528,6 +529,13 @@ def current_timestep(self): def interrupt(self): return self._interrupt + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=50, + max_sequence_length=256, + ) + def forward( self, req: OmniDiffusionRequest, @@ -536,10 +544,8 @@ def forward( guidance_scale: float = 5.0, height: int | None = None, width: int | None = None, - num_inference_steps: int = 50, sigmas: list[float] | None = None, num_images_per_prompt: int | None = 1, - generator: 
torch.Generator | list[torch.Generator] | None = None, latents: torch.FloatTensor | None = None, prompt_embeds: torch.FloatTensor | None = None, negative_prompt_embeds: torch.FloatTensor | None = None, @@ -548,7 +554,6 @@ def forward( joint_attention_kwargs: dict[str, Any] | None = None, callback_on_step_end: Callable[[int, int, dict], None] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 256, ) -> DiffusionOutput: r""" Function invoked when calling the pipeline for generation. @@ -628,12 +633,13 @@ def forward( height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_inference_steps = req.sampling_params.num_inference_steps + max_sequence_length = req.sampling_params.max_sequence_length sigmas = req.sampling_params.sigmas or sigmas guidance_scale = ( req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale ) - generator = req.sampling_params.generator or generator + generator = req.sampling_params.generator num_images_per_prompt = ( req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index 9f75c84538..91c880a81f 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -18,7 +18,6 @@ FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.torch_utils import randn_tensor -from torch import nn from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer from vllm.model_executor.models.utils import AutoWeightsLoader @@ -26,6 +25,7 @@ from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_qwenimage import DistributedAutoencoderKLQwenImage from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( QwenImageCFGParallelMixin, ) @@ -38,6 +38,7 @@ normalize_min_aligned_size, ) from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.inputs.data import DiffusionParamOverrides if TYPE_CHECKING: from vllm_omni.diffusion.worker.utils import DiffusionRequestState @@ -244,9 +245,18 @@ def apply_rotary_emb_qwen( return x_out.type_as(x) -class QwenImagePipeline(nn.Module, QwenImageCFGParallelMixin, DiffusionPipelineProfilerMixin): +class QwenImagePipeline(VllmDiffusionPipeline, QwenImageCFGParallelMixin, DiffusionPipelineProfilerMixin): supports_step_execution: ClassVar[bool] = True + # Overrides for default diffusion sampling params when using this pipeline + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=50, + true_cfg_scale=4.0, + max_sequence_length=512, + ) + def __init__( self, *, @@ -697,13 +707,13 @@ def prepare_encode( negative_prompt=negative_prompt, height=sampling.height or self.default_sample_size * self.vae_scale_factor, width=sampling.width or self.default_sample_size * self.vae_scale_factor, - num_inference_steps=sampling.num_inference_steps or 50, + 
num_inference_steps=sampling.num_inference_steps, sigmas=sampling.sigmas, guidance_scale=sampling.guidance_scale if sampling.guidance_scale_provided else 1.0, num_images_per_prompt=sampling.num_outputs_per_prompt if sampling.num_outputs_per_prompt > 0 else 1, generator=sampling.generator, - true_cfg_scale=sampling.true_cfg_scale or 4.0, - max_sequence_length=sampling.max_sequence_length or 512, + true_cfg_scale=sampling.true_cfg_scale, + max_sequence_length=sampling.max_sequence_length, attention_kwargs=kwargs.get("attention_kwargs"), ) @@ -866,15 +876,12 @@ def denoise_step( }, ) - true_cfg_scale = state.sampling.true_cfg_scale or 4.0 - cfg_normalize = state.sampling.cfg_normalize - return self.predict_noise_maybe_with_cfg( state.do_true_cfg, - true_cfg_scale, + state.sampling.true_cfg_scale, positive_kwargs, negative_kwargs, - cfg_normalize, + state.sampling.cfg_normalize, output_slice, ) @@ -918,14 +925,11 @@ def forward( req: OmniDiffusionRequest, prompt: str | list[str] | None = None, negative_prompt: str | list[str] | None = None, - true_cfg_scale: float = 4.0, height: int | None = None, width: int | None = None, - num_inference_steps: int = 50, sigmas: list[float] | None = None, guidance_scale: float = 1.0, num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, @@ -934,7 +938,6 @@ def forward( output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, ) -> DiffusionOutput: extracted_prompt, negative_prompt = self._extract_prompts(req.prompts) prompt = extracted_prompt or prompt @@ -942,11 +945,11 @@ def forward( height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor height, width = normalize_min_aligned_size(height, width, self.vae_scale_factor * 2) - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_inference_steps = req.sampling_params.num_inference_steps sigmas = req.sampling_params.sigmas or sigmas - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length - generator = req.sampling_params.generator or generator - true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + max_sequence_length = req.sampling_params.max_sequence_length + generator = req.sampling_params.generator + true_cfg_scale = req.sampling_params.true_cfg_scale if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale num_images_per_prompt = ( diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index dd77d71b1e..0e954cc898 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -20,14 +20,13 @@ FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.torch_utils import randn_tensor -from torch import nn from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.utils import get_local_device from 
vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.interface import SupportImageInput, VllmDiffusionPipeline from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( QwenImageCFGParallelMixin, ) @@ -41,7 +40,7 @@ normalize_min_aligned_size, ) from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs -from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.inputs.data import DiffusionParamOverrides, OmniTextPrompt from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -217,7 +216,18 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class QwenImageEditPipeline(nn.Module, SupportImageInput, QwenImageCFGParallelMixin, DiffusionPipelineProfilerMixin): +class QwenImageEditPipeline( + VllmDiffusionPipeline, SupportImageInput, QwenImageCFGParallelMixin, DiffusionPipelineProfilerMixin +): + # Overrides for default diffusion sampling params when using this pipeline + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=50, + true_cfg_scale=4.0, + max_sequence_length=512, + ) + def __init__( self, *, @@ -608,14 +618,11 @@ def forward( prompt: str | list[str] | None = None, negative_prompt: str | list[str] | None = None, image: PIL.Image.Image | torch.Tensor | None = None, - true_cfg_scale: float = 4.0, height: int | None = None, width: int | None = None, - num_inference_steps: int = 50, sigmas: list[float] | None = None, guidance_scale: float = 1.0, num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, @@ -624,7 +631,6 @@ def forward( output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, ) -> DiffusionOutput: """Forward pass for image editing.""" # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") @@ -670,11 +676,11 @@ def forward( image = self.image_processor.preprocess(image, calculated_height, calculated_width) image = image.unsqueeze(2) - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_inference_steps = req.sampling_params.num_inference_steps sigmas = req.sampling_params.sigmas or sigmas - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length - generator = req.sampling_params.generator or generator - true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + max_sequence_length = req.sampling_params.max_sequence_length + generator = req.sampling_params.generator + true_cfg_scale = req.sampling_params.true_cfg_scale if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale num_images_per_prompt = ( diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 6f6c9d2ba3..6994f2476b 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -18,14 +18,13 @@ FlowMatchEulerDiscreteScheduler, ) from 
diffusers.utils.torch_utils import randn_tensor -from torch import nn from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.interface import SupportImageInput, VllmDiffusionPipeline from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( QwenImageCFGParallelMixin, ) @@ -44,7 +43,7 @@ normalize_min_aligned_size, ) from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs -from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.inputs.data import DiffusionParamOverrides, OmniTextPrompt from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -169,8 +168,17 @@ def post_process_func( class QwenImageEditPlusPipeline( - nn.Module, SupportImageInput, QwenImageCFGParallelMixin, DiffusionPipelineProfilerMixin + VllmDiffusionPipeline, SupportImageInput, QwenImageCFGParallelMixin, DiffusionPipelineProfilerMixin ): + # Overrides for default diffusion sampling params when using this pipeline + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=50, + true_cfg_scale=4.0, + max_sequence_length=512, + ) + def __init__( self, *, @@ -541,14 +549,11 @@ def forward( prompt: str | list[str] | None = None, negative_prompt: str | list[str] | None = None, image: PIL.Image.Image | list[PIL.Image.Image] | torch.Tensor | None = None, - true_cfg_scale: float = 4.0, height: int | None = None, width: int | None = None, - num_inference_steps: int = 50, sigmas: list[float] | None = None, guidance_scale: float = 1.0, num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, @@ -557,7 +562,6 @@ def forward( output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, ) -> DiffusionOutput: """Forward pass for image editing with support for multiple images.""" # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") @@ -623,11 +627,11 @@ def forward( condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_inference_steps = req.sampling_params.num_inference_steps sigmas = req.sampling_params.sigmas or sigmas - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length - generator = req.sampling_params.generator or generator - true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + max_sequence_length = req.sampling_params.max_sequence_length + generator = req.sampling_params.generator + true_cfg_scale = req.sampling_params.true_cfg_scale if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale num_images_per_prompt = ( diff --git 
a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index 38866d89c5..6897fd7019 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -17,14 +17,13 @@ FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.torch_utils import randn_tensor -from torch import nn from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.interface import SupportImageInput, VllmDiffusionPipeline from vllm_omni.diffusion.models.qwen_image.autoencoder_kl_qwenimage import ( AutoencoderKLQwenImage, ) @@ -40,7 +39,7 @@ normalize_min_aligned_size, ) from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs -from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.inputs.data import DiffusionParamOverrides, OmniTextPrompt from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -198,9 +197,20 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class QwenImageLayeredPipeline(nn.Module, SupportImageInput, QwenImageCFGParallelMixin, DiffusionPipelineProfilerMixin): +class QwenImageLayeredPipeline( + VllmDiffusionPipeline, SupportImageInput, QwenImageCFGParallelMixin, DiffusionPipelineProfilerMixin +): color_format = "RGBA" + # Overrides for default diffusion sampling params when using this pipeline + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=50, + true_cfg_scale=4.0, + max_sequence_length=512, + ) + def __init__( self, *, @@ -589,13 +599,10 @@ def forward( image: PIL.Image.Image | torch.Tensor | None = None, prompt: str | list[str] | None = None, negative_prompt: str | list[str] | None = None, - true_cfg_scale: float = 4.0, layers: int | None = 4, - num_inference_steps: int = 50, sigmas: list[float] | None = None, guidance_scale: float | None = None, num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, @@ -603,7 +610,6 @@ def forward( negative_prompt_embeds_mask: torch.Tensor | None = None, output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, - max_sequence_length: int = 512, resolution: int = 640, cfg_normalize: bool = False, use_en_prompt: bool = False, @@ -625,17 +631,17 @@ def forward( layers = req.sampling_params.layers if req.sampling_params.layers is not None else layers resolution = req.sampling_params.resolution if req.sampling_params.resolution is not None else resolution - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + max_sequence_length = req.sampling_params.max_sequence_length cfg_normalize = ( req.sampling_params.cfg_normalize if req.sampling_params.cfg_normalize is not None else cfg_normalize ) use_en_prompt = ( req.sampling_params.use_en_prompt if 
req.sampling_params.use_en_prompt is not None else use_en_prompt ) - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_inference_steps = req.sampling_params.num_inference_steps sigmas = req.sampling_params.sigmas or sigmas - generator = req.sampling_params.generator or generator - true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + generator = req.sampling_params.generator + true_cfg_scale = req.sampling_params.true_cfg_scale if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale num_images_per_prompt = ( diff --git a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py index 86b2a187ba..95d798fa4f 100644 --- a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py +++ b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py @@ -10,7 +10,6 @@ FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.torch_utils import randn_tensor -from torch import nn from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5Tokenizer from vllm.model_executor.models.utils import AutoWeightsLoader @@ -19,11 +18,13 @@ from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline from vllm_omni.diffusion.models.sd3.sd3_transformer import ( SD3Transformer2DModel, ) from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import DiffusionParamOverrides from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -128,7 +129,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class StableDiffusion3Pipeline(nn.Module, CFGParallelMixin, DiffusionPipelineProfilerMixin): +class StableDiffusion3Pipeline(VllmDiffusionPipeline, CFGParallelMixin, DiffusionPipelineProfilerMixin): def __init__( self, *, @@ -499,6 +500,13 @@ def current_timestep(self): def interrupt(self): return self._interrupt + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=28, + max_sequence_length=256, + ) + def diffuse( self, latents: torch.Tensor, @@ -581,16 +589,13 @@ def forward( negative_prompt_3: str | list[str] = "", height: int | None = None, width: int | None = None, - num_inference_steps: int = 28, sigmas: list[float] | None = None, num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, pooled_prompt_embeds: torch.Tensor | None = None, negative_pooled_prompt_embeds: torch.Tensor | None = None, - max_sequence_length: int = 256, ) -> DiffusionOutput: # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") # TODO: May be some data formatting operations on the API side. Hack for now. 
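The same mechanical edit recurs in every pipeline in this diff: value = req.sampling_params.value or fallback becomes a plain attribute read. The reason is that "or" cannot tell "unset" apart from an explicitly requested falsy value. A toy illustration of the old pitfall (the numbers are invented):

    # Old pattern: an explicitly requested falsy value is silently clobbered.
    requested_steps = 0            # the caller really asked for 0
    resolved = requested_steps or 40
    assert resolved == 40          # the explicit 0 is lost

    # New pattern: pipeline defaults are merged once, up front, keyed on which
    # kwargs the caller actually passed (tracked via _init_kwargs, added to
    # inputs/data.py below), so forward() can read the field directly.
    resolved = requested_steps     # 0 survives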
@@ -602,9 +607,9 @@ def forward( height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor sigmas = req.sampling_params.sigmas or sigmas - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - generator = req.sampling_params.generator or generator + max_sequence_length = req.sampling_params.max_sequence_length + num_inference_steps = req.sampling_params.num_inference_steps + generator = req.sampling_params.generator num_images_per_prompt = ( req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 diff --git a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py index 7950fb4915..61745035cf 100644 --- a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py +++ b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py @@ -19,7 +19,6 @@ from diffusers.pipelines.stable_audio.modeling_stable_audio import StableAudioProjectionModel from diffusers.schedulers import CosineDPMSolverMultistepScheduler from diffusers.utils.torch_utils import randn_tensor -from torch import nn from transformers import T5EncoderModel, T5TokenizerFast from vllm.logger import init_logger from vllm.model_executor.models.utils import AutoWeightsLoader @@ -27,11 +26,12 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.interface import SupportAudioOutput +from vllm_omni.diffusion.models.interface import SupportAudioOutput, VllmDiffusionPipeline from vllm_omni.diffusion.models.stable_audio.stable_audio_transformer import StableAudioDiTModel from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.inputs.data import DiffusionParamOverrides logger = init_logger(__name__) @@ -60,7 +60,7 @@ def post_process_func( return post_process_func -class StableAudioPipeline(nn.Module, SupportAudioOutput, DiffusionPipelineProfilerMixin): +class StableAudioPipeline(VllmDiffusionPipeline, SupportAudioOutput, DiffusionPipelineProfilerMixin): """ Pipeline for text-to-audio generation using Stable Audio Open. 
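The default magnitudes differ per modality: the audio pipeline below asks for 100 steps where the image pipelines above ask for 28 to 50. Because these numbers are wrapped in DiffusionParamOverrides, a typo in a field name fails at construction instead of being silently ignored; roughly:

    from vllm_omni.inputs.data import DiffusionParamOverrides

    DiffusionParamOverrides(num_inference_steps=100)       # ok: a real field
    try:
        DiffusionParamOverrides(num_inferenze_steps=100)   # deliberate typo
    except AttributeError as err:
        print(err)  # num_inferenze_steps is not a valid OmniDiffusionSamplingParams field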
@@ -72,6 +72,12 @@ class StableAudioPipeline(nn.Module, SupportAudioOutput, DiffusionPipelineProfil prefix: Weight prefix for loading (default: "") """ + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=100, + ) + def __init__( self, *, @@ -360,10 +366,8 @@ def forward( negative_prompt: str | list[str] | None = None, audio_end_in_s: float | None = None, audio_start_in_s: float = 0.0, - num_inference_steps: int = 100, guidance_scale: float = 7.0, num_waveforms_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, @@ -399,12 +403,11 @@ def forward( elif req.prompts: negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_inference_steps = req.sampling_params.num_inference_steps if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale - if generator is None: - generator = req.sampling_params.generator + generator = req.sampling_params.generator if generator is None and req.sampling_params.seed is not None: generator = torch.Generator(device=self.device).manual_seed(req.sampling_params.seed) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 84d89619e8..43d410fe41 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -22,13 +22,14 @@ from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler from vllm_omni.diffusion.models.wan2_2.scheduling_wan_euler import WanEulerScheduler from vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import WanTransformer3DModel from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.inputs.data import DiffusionParamOverrides, OmniTextPrompt from vllm_omni.platforms import current_omni_platform logger = logging.getLogger(__name__) @@ -231,7 +232,7 @@ def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: return pre_process_func -class Wan22Pipeline(nn.Module, CFGParallelMixin, ProgressBarMixin, DiffusionPipelineProfilerMixin): +class Wan22Pipeline(VllmDiffusionPipeline, CFGParallelMixin, ProgressBarMixin, DiffusionPipelineProfilerMixin): def __init__( self, *, @@ -357,6 +358,13 @@ def _create_transformer(self, config: dict) -> WanTransformer3DModel: """Create a transformer from a config dict. 
Subclasses may override.""" return create_transformer_from_config(config) + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=40, + max_sequence_length=512, + ) + @property def guidance_scale(self): return self._guidance_scale @@ -380,11 +388,9 @@ def forward( negative_prompt: str | None = None, height: int = 480, width: int = 832, - num_inference_steps: int = 40, guidance_scale: float | tuple[float, float] = 4.0, frame_num: int = 81, output_type: str | None = "np", - generator: torch.Generator | list[torch.Generator] | None = None, prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, attention_kwargs: dict | None = None, @@ -412,7 +418,7 @@ def forward( mod_value = self.vae_scale_factor_spatial * patch_size[1] # 16*2=32 for TI2V, 8*2=16 for I2V height = (height // mod_value) * mod_value width = (width // mod_value) * mod_value - num_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_steps = req.sampling_params.num_inference_steps # Respect per-request guidance_scale when explicitly provided. if req.sampling_params.guidance_scale_provided: @@ -467,8 +473,7 @@ def forward( dtype = self.text_encoder.dtype # Seed / generator - if generator is None: - generator = req.sampling_params.generator + generator = req.sampling_params.generator if generator is None and req.sampling_params.seed is not None: generator = torch.Generator(device=device).manual_seed(req.sampling_params.seed) @@ -483,7 +488,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or 512, + max_sequence_length=req.sampling_params.max_sequence_length, device=device, dtype=dtype, ) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 46484cd789..5ba4e6763f 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -22,7 +22,7 @@ from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.interface import SupportImageInput, VllmDiffusionPipeline from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( build_wan_scheduler, @@ -34,7 +34,7 @@ ) from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.inputs.data import DiffusionParamOverrides, OmniTextPrompt from vllm_omni.platforms import current_omni_platform logger = logging.getLogger(__name__) @@ -145,7 +145,7 @@ def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: class Wan22I2VPipeline( - nn.Module, SupportImageInput, CFGParallelMixin, ProgressBarMixin, DiffusionPipelineProfilerMixin + VllmDiffusionPipeline, SupportImageInput, CFGParallelMixin, ProgressBarMixin, DiffusionPipelineProfilerMixin ): """ Wan2.2 Image-to-Video Pipeline. 
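In the Wan pipelines (as in stable_audio above), the in-forward "if generator is None" fallback shrinks to a plain read because the runner now resolves the generator once per request before forward() runs. A standalone sketch of the resolution order that _finalize_sampling_params, added further down in this diff, implements:

    import torch


    def resolve_generator(generator, seed, generator_device, runner_device):
        # 1. An explicitly provided generator always wins.
        if generator is not None:
            return generator
        # 2. Otherwise a seed deterministically creates one, preferring the
        #    caller's generator_device, and keeping CPU runners on CPU.
        if seed is not None:
            if generator_device is not None:
                device = generator_device
            elif runner_device.type == "cpu":
                device = "cpu"
            else:
                device = runner_device
            return torch.Generator(device=device).manual_seed(seed)
        # 3. No generator and no seed: sampling stays non-deterministic.
        return None


    gen = resolve_generator(None, 1234, None, torch.device("cpu"))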
@@ -254,6 +254,13 @@ def __init__( enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler ) + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=40, + max_sequence_length=512, + ) + @property def guidance_scale(self): return self._guidance_scale @@ -293,11 +300,9 @@ def forward( image: PIL.Image.Image | torch.Tensor | None = None, height: int = 480, width: int = 832, - num_inference_steps: int = 40, guidance_scale: float | tuple[float, float] = 5.0, frame_num: int = 81, output_type: str | None = "np", - generator: torch.Generator | list[torch.Generator] | None = None, prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, image_embeds: torch.Tensor | None = None, @@ -340,7 +345,7 @@ def forward( height = req.sampling_params.height or height width = req.sampling_params.width or width num_frames = req.sampling_params.num_frames or frame_num - num_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_steps = req.sampling_params.num_inference_steps # Respect per-request guidance_scale when explicitly provided. if req.sampling_params.guidance_scale_provided: @@ -389,8 +394,7 @@ def forward( dtype = self.transformer.dtype # Generator setup - if generator is None: - generator = req.sampling_params.generator + generator = req.sampling_params.generator if generator is None and req.sampling_params.seed is not None: generator = torch.Generator(device=device).manual_seed(req.sampling_params.seed) @@ -406,7 +410,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or 512, + max_sequence_length=req.sampling_params.max_sequence_length, device=device, dtype=dtype, ) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index 939fe294a3..09384a06b4 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -34,7 +34,7 @@ from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.interface import SupportImageInput, VllmDiffusionPipeline from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( build_wan_scheduler, @@ -45,7 +45,7 @@ retrieve_latents, ) from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.inputs.data import DiffusionParamOverrides, OmniTextPrompt from vllm_omni.platforms import current_omni_platform logger = logging.getLogger(__name__) @@ -131,7 +131,7 @@ def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: return pre_process_func -class Wan22TI2VPipeline(nn.Module, SupportImageInput, CFGParallelMixin, ProgressBarMixin): +class Wan22TI2VPipeline(VllmDiffusionPipeline, SupportImageInput, CFGParallelMixin, ProgressBarMixin): """ Wan2.2 Text-Image-to-Video (TI2V) Pipeline. 
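Condensed into one place, this is the conversion recipe applied to every pipeline in this diff (MyPipeline and its numbers are hypothetical):

    from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline
    from vllm_omni.diffusion.request import OmniDiffusionRequest
    from vllm_omni.inputs.data import DiffusionParamOverrides


    class MyPipeline(VllmDiffusionPipeline):  # was: nn.Module
        @property
        def sampling_param_defaults(self):
            # Values that previously lived as forward() keyword defaults.
            return DiffusionParamOverrides(num_inference_steps=40, max_sequence_length=512)

        def forward(self, req: OmniDiffusionRequest):
            # The runner has already merged the defaults, so read fields
            # directly instead of "req.sampling_params.x or local_default".
            num_inference_steps = req.sampling_params.num_inference_steps
            ...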
@@ -200,6 +200,13 @@ def __init__( self._num_timesteps = None self._current_timestep = None + @property + def sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=40, + max_sequence_length=512, + ) + @property def guidance_scale(self): return self._guidance_scale @@ -224,11 +231,9 @@ def forward( image: PIL.Image.Image | torch.Tensor | None = None, height: int = 704, width: int = 1280, - num_inference_steps: int = 40, guidance_scale: float = 5.0, frame_num: int = 81, output_type: str | None = "np", - generator: torch.Generator | list[torch.Generator] | None = None, prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, attention_kwargs: dict | None = None, @@ -270,7 +275,7 @@ def forward( height = req.sampling_params.height or height width = req.sampling_params.width or width num_frames = req.sampling_params.num_frames if req.sampling_params.num_frames else frame_num - num_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_steps = req.sampling_params.num_inference_steps # Respect per-request guidance_scale when explicitly provided. if req.sampling_params.guidance_scale_provided: @@ -298,8 +303,7 @@ def forward( dtype = self.transformer.dtype # Generator setup - if generator is None: - generator = req.sampling_params.generator + generator = req.sampling_params.generator if generator is None and req.sampling_params.seed is not None: generator = torch.Generator(device=device).manual_seed(req.sampling_params.seed) @@ -310,7 +314,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_scale > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or 512, + max_sequence_length=req.sampling_params.max_sequence_length, device=device, dtype=dtype, ) diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index 5bea59a209..6082c4e615 100644 --- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -23,7 +23,6 @@ import PIL.Image import torch -import torch.nn as nn from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import logging @@ -35,11 +34,13 @@ from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl import DistributedAutoencoderKL from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline from vllm_omni.diffusion.models.z_image.z_image_transformer import ( ZImageTransformer2DModel, ) from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import DiffusionParamOverrides from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -158,7 +159,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class ZImagePipeline(nn.Module, DiffusionPipelineProfilerMixin): +class ZImagePipeline(VllmDiffusionPipeline, DiffusionPipelineProfilerMixin): def __init__( self, *, @@ -374,6 +375,13 @@ def num_timesteps(self): def interrupt(self): return self._interrupt + @property + def 
sampling_param_defaults(self): + return DiffusionParamOverrides( + num_inference_steps=50, + max_sequence_length=512, + ) + def forward( self, req: OmniDiffusionRequest, @@ -382,14 +390,12 @@ def forward( strength: float = 0.6, height: int = 1024, width: int = 1024, - num_inference_steps: int = 50, sigmas: list[float] | None = None, guidance_scale: float = 5.0, cfg_normalization: bool = False, cfg_truncation: float = 1.0, negative_prompt: str | list[str] | None = None, num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.FloatTensor | None = None, prompt_embeds: list[torch.FloatTensor] | None = None, negative_prompt_embeds: list[torch.FloatTensor] | None = None, @@ -398,7 +404,6 @@ def forward( joint_attention_kwargs: dict[str, Any] | None = None, callback_on_step_end: Callable[[int, int, dict], None] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, ) -> DiffusionOutput: r""" Function invoked when calling the pipeline for generation. @@ -520,10 +525,10 @@ def forward( height = req.sampling_params.height or height width = req.sampling_params.width or width - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_inference_steps = req.sampling_params.num_inference_steps generator = req.sampling_params.generator sigmas = req.sampling_params.sigmas or sigmas - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + max_sequence_length = req.sampling_params.max_sequence_length guidance_scale = ( req.sampling_params.guidance_scale if req.sampling_params.guidance_rescale is not None else guidance_scale ) diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 535f053c38..acdddcc6d5 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -27,7 +27,7 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.forward_context import set_forward_context from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.interface import supports_step_execution +from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline, supports_step_execution from vllm_omni.diffusion.offloader import get_offload_backend from vllm_omni.diffusion.registry import _NO_CACHE_ACCELERATION from vllm_omni.diffusion.request import OmniDiffusionRequest @@ -218,6 +218,25 @@ def _record_peak_memory(self, output: DiffusionOutput) -> None: pool_overhead_gb / peak_reserved_gb * 100 if peak_reserved_gb > 0 else 0.0, ) + def _finalize_sampling_params(self, sampling_params): + """Finalizes the sampling params by adding pipeline defaults; + NOTE: This is done in place.""" + # Resolve the sampling params generator first + if sampling_params.generator is None and sampling_params.seed is not None: + if sampling_params.generator_device is not None: + gen_device = sampling_params.generator_device + elif self.device.type == "cpu": + gen_device = "cpu" + else: + gen_device = self.device + sampling_params.generator = torch.Generator(device=gen_device).manual_seed(sampling_params.seed) + + # Apply model specific defaults to unset fields + if isinstance(self.pipeline, VllmDiffusionPipeline): + logger.debug("Merging default sampling params into user request") + 
sampling_params.merge_with_def_params(self.pipeline.sampling_param_defaults) + return sampling_params + def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: """ Execute a forward pass for the given requests. @@ -238,6 +257,9 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: if len(req.prompts) == 0: raise ValueError("Cannot execute model with empty request list") + # set the generator and pipeline defaults for sampling params. + sampling_params = self._finalize_sampling_params(req.sampling_params) + # Use no_grad() for HSDP compatibility, inference_mode() otherwise for better perf use_hsdp = self.od_config.parallel_config.use_hsdp grad_context = torch.no_grad() if use_hsdp else torch.inference_mode() @@ -249,23 +271,14 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: target_device=getattr(self.pipeline, "device", None), ) - if req.sampling_params.generator is None and req.sampling_params.seed is not None: - if req.sampling_params.generator_device is not None: - gen_device = req.sampling_params.generator_device - elif self.device.type == "cpu": - gen_device = "cpu" - else: - gen_device = self.device - req.sampling_params.generator = torch.Generator(device=gen_device).manual_seed(req.sampling_params.seed) - # Refresh cache context if needed if ( not getattr(req, "skip_cache_refresh", False) and self.cache_backend is not None and self.cache_backend.is_enabled() - and req.sampling_params.num_inference_steps is not None + and sampling_params.num_inference_steps is not None ): - self.cache_backend.refresh(self.pipeline, req.sampling_params.num_inference_steps) + self.cache_backend.refresh(self.pipeline, sampling_params.num_inference_steps) is_primary = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 if is_primary: @@ -357,16 +370,11 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> Runner state, is_new_request = self._update_states(scheduler_output) if is_new_request: - # TODO: support kv manager recv - # TODO: support cache backend - if state.sampling.generator is None and state.sampling.seed is not None: - if state.sampling.generator_device is not None: - gen_device = state.sampling.generator_device - elif self.device.type == "cpu": - gen_device = "cpu" - else: - gen_device = self.device - state.sampling.generator = torch.Generator(device=gen_device).manual_seed(state.sampling.seed) + # set the generator and pipeline defaults for sampling params. 
+ state.sampling = self._finalize_sampling_params(state.sampling) + + # TODO: support kv manager recv + # TODO: support cache backend with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): # step0/new request: encode diff --git a/vllm_omni/inputs/data.py b/vllm_omni/inputs/data.py index 9cb6c44335..59552b6711 100644 --- a/vllm_omni/inputs/data.py +++ b/vllm_omni/inputs/data.py @@ -1,7 +1,9 @@ import copy import pprint -from dataclasses import asdict, dataclass, field -from typing import Any, TypeAlias, TypedDict +from collections.abc import Callable +from dataclasses import dataclass, field, fields +from functools import wraps +from typing import Any, TypeAlias, TypedDict, TypeVar from vllm.inputs import PromptType from vllm.sampling_params import SamplingParams @@ -18,6 +20,30 @@ import torch from vllm.inputs import EmbedsPrompt, TextPrompt, TokensInput, TokensPrompt +_T = TypeVar("_T") + + +def track_init_args(cls: type[_T]) -> type[_T]: + """Decorator that wraps __init__ to track which kwargs were explicitly + passed by the caller, so that merge_with_def_params can distinguish + 'caller set this to 0' from 'caller never touched this'. + + NOTE: This decorator preserves the original __init__ signature for + type checkers while adding runtime tracking of explicitly-passed kwargs. + """ + original_init: Callable[..., None] = cls.__init__ + + @wraps(original_init) + def new_init(self: _T, *args: Any, **kwargs: Any) -> None: + # Call the original init first (which sets _init_kwargs to empty set) + original_init(self, *args, **kwargs) + # Then track which keyword arguments were explicitly passed + self._init_kwargs: set[str] = set(kwargs.keys()) # type: ignore[attr-defined] + + # Replace __init__ - type: ignore needed due to limitations in typing dynamic method replacement + cls.__init__ = new_init # type: ignore[method-assign] + return cls + class OmniTextPrompt(TextPrompt): """Text prompt with optional embeddings and additional information. @@ -170,6 +196,7 @@ def token_inputs_omni( return inputs +@track_init_args @dataclass class OmniDiffusionSamplingParams: """ @@ -178,8 +205,15 @@ class OmniDiffusionSamplingParams: This dataclass contains all information needed during the diffusion pipeline execution, allowing methods to update specific components without needing to manage numerous individual parameters. + + The @track_init_args decorator records which kwargs the caller explicitly + passed, so merge_with_def_params can fill in pipeline defaults only for + fields the caller never touched. """ + # Set by the @track_init_args decorator at runtime; excluded from __init__ + _init_kwargs: set[str] = field(init=False, default_factory=set) + # Additional text-related parameters max_sequence_length: int | None = None prompt_template: dict[str, Any] | None = None @@ -234,8 +268,7 @@ class OmniDiffusionSamplingParams: step_index: int | None = None boundary_ratio: float | None = None - # Scheduler parameters – ``None`` means "not explicitly set by the caller"; - # each pipeline's ``forward()`` decides its own model-specific default. 
+ # Scheduler parameters num_inference_steps: int | None = None guidance_scale: float = 0.0 guidance_scale_provided: bool = False @@ -247,7 +280,7 @@ eta: float = 0.0 sigmas: list[float] | None = None - true_cfg_scale: float | None = None # qwen-image specific now + true_cfg_scale: float | None = None # qwen-image specific for now n_tokens: int | None = None extra_step_kwargs: dict[str, Any] = field(default_factory=dict) @@ -327,10 +360,38 @@ def resolved_frame_rate(self) -> float | None: return float(fps) def __str__(self): - return pprint.pformat(asdict(self), indent=2, width=120) + return pprint.pformat({f.name: getattr(self, f.name) for f in fields(self)}, indent=2, width=120) def clone(self) -> "OmniDiffusionSamplingParams": return copy.deepcopy(self) + def to_dict(self) -> dict[str, Any]: + """Serialize to a plain dict (e.g., for IPC).""" + return {f.name: getattr(self, f.name) for f in fields(self)} + + def merge_with_def_params(self, def_params: "DiffusionParamOverrides"): + """Merge a pipeline's default param overrides into this instance, in place. + + Only fills in fields that the caller did not explicitly pass to + __init__; explicitly-set values (including falsy ones like 0 or + False) are preserved. + """ + for attr_name, attr_val in def_params.validated_overrides.items(): + if attr_name not in self._init_kwargs: + setattr(self, attr_name, attr_val) + + +class DiffusionParamOverrides: + """A validated mapping of OmniDiffusionSamplingParams field names to the default values a pipeline wants for them.""" + + def __init__(self, **kwargs) -> None: + valid_keys = {f.name for f in fields(OmniDiffusionSamplingParams)} + for attr_name in kwargs: + if attr_name not in valid_keys: + raise AttributeError(f"{attr_name} is not a valid OmniDiffusionSamplingParams field") + + # TODO: it would also be nice to validate the override values' types + self.validated_overrides = kwargs + OmniSamplingParams: TypeAlias = SamplingParams | OmniDiffusionSamplingParams
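To make the tracking mechanics concrete, here is a toy stand-in for OmniDiffusionSamplingParams (the Toy class is invented for illustration; the decorator is the one defined above):

    from dataclasses import dataclass, field

    from vllm_omni.inputs.data import track_init_args


    @track_init_args
    @dataclass
    class Toy:
        _init_kwargs: set[str] = field(init=False, default_factory=set)
        steps: int | None = None
        scale: float = 0.0


    t = Toy(scale=0.0)
    assert t._init_kwargs == {"scale"}    # explicitly passed, even though falsy
    assert "steps" not in t._init_kwargs  # untouched, so a merge may fill it in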