Skip to content
90 changes: 90 additions & 0 deletions tests/diffusion/inputs/test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""
Tests for sampling parameter override behaviors.
"""

from copy import deepcopy

from vllm_omni.inputs.data import DiffusionParamOverrides, OmniDiffusionSamplingParams


def test_merge_nothing():
    """Merging an empty override set must leave the params untouched."""
    params = OmniDiffusionSamplingParams()
    snapshot = deepcopy(params)
    params.merge_with_def_params(DiffusionParamOverrides())
    # No field changed, and nothing was recorded as explicitly passed.
    assert params.__dict__ == snapshot.__dict__
    assert params._init_kwargs == set()


def test_merge_unset():
    """Ensure that we can override fields that are unset.

    A field the user never passed should take the pipeline-provided
    default, and must NOT be recorded as an explicitly passed kwarg.
    """
    default_steps = 777
    user_params = OmniDiffusionSamplingParams()
    overrides = DiffusionParamOverrides(num_inference_steps=default_steps)
    user_params.merge_with_def_params(overrides)
    # Assert against the variable rather than a duplicated literal, for a
    # single source of truth and consistency with the other merge tests.
    assert user_params.num_inference_steps == default_steps
    # The value came from the pipeline override, not the user.
    assert user_params._init_kwargs == set()


def test_merge_priority():
    """Ensure that explicitly passed values won't be overridden by pipelines."""
    explicit_steps, pipeline_steps = 888, 777
    params = OmniDiffusionSamplingParams(num_inference_steps=explicit_steps)
    params.merge_with_def_params(
        DiffusionParamOverrides(num_inference_steps=pipeline_steps)
    )
    # The user's explicit value wins, and is tracked as explicitly passed.
    assert params.num_inference_steps == explicit_steps
    assert params._init_kwargs == {"num_inference_steps"}


def test_merge_multiple():
    """Ensure that we can merge over truthy or falsy default values."""
    expected = {
        "num_inference_steps": 888,  # falsy (None) by default
        "resolution": 320,  # 640 by default
    }
    params = OmniDiffusionSamplingParams()
    params.merge_with_def_params(DiffusionParamOverrides(**expected))
    # Both fields pick up the pipeline values, and neither is recorded
    # as explicitly passed by the user.
    for field_name, value in expected.items():
        assert getattr(params, field_name) == value
    assert params._init_kwargs == set()


def test_hierarchical_merge_complex():
    """Tests merge priority with multiple values."""
    explicit = {"num_inference_steps": 100, "height": 100, "width": 100}
    pipeline_steps = 888  # clobbered by the user's explicit steps
    pipeline_resolution = 320

    params = OmniDiffusionSamplingParams(**explicit)
    params.merge_with_def_params(
        DiffusionParamOverrides(
            num_inference_steps=pipeline_steps,  # lower priority than user param
            resolution=pipeline_resolution,
        )
    )
    # Every explicitly passed field survives the merge...
    for field_name, value in explicit.items():
        assert getattr(params, field_name) == value
    # ...while fields the user never set pick up the pipeline default.
    assert params.resolution == pipeline_resolution
    assert params._init_kwargs == set(explicit)


def test_can_pass_falsy_override():
    """An explicitly passed falsy value (None) must survive the merge."""
    params = OmniDiffusionSamplingParams(num_inference_steps=None)
    params.merge_with_def_params(
        DiffusionParamOverrides(num_inference_steps=100)
    )
    # None was passed explicitly, so the pipeline default must not clobber it.
    assert params.num_inference_steps is None
    assert params._init_kwargs == {"num_inference_steps"}
88 changes: 88 additions & 0 deletions tests/diffusion/models/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fast interface checks for all Diffusion pipelines."""

from typing import cast

import pytest

from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline
from vllm_omni.diffusion.registry import DiffusionModelRegistry
from vllm_omni.inputs.data import DiffusionParamOverrides, OmniDiffusionSamplingParams

# Pipelines to omit from common tests; this should be done sparingly
# as the tests are generic, and skips should only be added for pipelines
# that genuinely cannot satisfy the shared checks below.
SKIP_PIPELINES = ["DreamIDOmniPipeline"]

# Instance variables that need to be mocked for sampling_param_defaults
# (these pipelines read instance state when building their defaults).
INSTANCE_VAR_MOCKS = {
    "LTX2Pipeline": {"tokenizer_max_length": 512},
    "LTX2ImageToVideoPipeline": {"tokenizer_max_length": 512},
}

# Every registered pipeline except the explicit skips.
TEST_PIPELINES = [pipe for pipe in DiffusionModelRegistry.models.keys() if pipe not in SKIP_PIPELINES]


@pytest.mark.parametrize("pipeline_type", TEST_PIPELINES)
def test_pipelines_are_vllm_diffusion_pipeline(pipeline_type):
    """Ensure all pipelines are instances of VllmDiffusionPipeline"""
    cls = DiffusionModelRegistry._try_load_model_cls(pipeline_type)
    assert cls is not None
    assert issubclass(cls, VllmDiffusionPipeline)


@pytest.mark.parametrize("pipeline_type", TEST_PIPELINES)
def test_pipeline_sampling_params_are_valid(pipeline_type):
    """Ensure all pipelines define sampling_param_defaults with valid param kwargs."""
    cls = DiffusionModelRegistry._try_load_model_cls(pipeline_type)
    assert cls is not None

    # Bypass __init__ (and therefore model loading) entirely; the vast
    # majority of models do not use instance vars in their default params.
    instance = object.__new__(cls)

    # Patch instance variables for the few pipelines that do need them.
    for attr_name, attr_value in INSTANCE_VAR_MOCKS.get(pipeline_type, {}).items():
        setattr(instance, attr_name, attr_value)

    # sampling_param_defaults must exist and be non-empty: at a minimum
    # every class inherits num_inference_steps from the base class.
    defaults = instance.sampling_param_defaults
    assert isinstance(defaults, DiffusionParamOverrides)
    assert hasattr(defaults, "validated_overrides")
    assert len(defaults.validated_overrides) > 0

    # The override kwargs must round-trip through the sampling-params
    # constructor, i.e. every key is a valid parameter name.
    params = OmniDiffusionSamplingParams(**defaults.validated_overrides)
    for attr_name, val in defaults.validated_overrides.items():
        assert hasattr(params, attr_name)
        assert getattr(params, attr_name) == val


@pytest.mark.parametrize("pipeline_type", TEST_PIPELINES)
def test_merge_sampling_params(pipeline_type):
    """Test sampling param / override merging."""
    user_steps = 999  # overrides all pipeline defaults
    cls = DiffusionModelRegistry._try_load_model_cls(pipeline_type)
    assert cls is not None
    params = OmniDiffusionSamplingParams(num_inference_steps=user_steps)

    # Bypass __init__ (and therefore model loading); most models do not
    # use instance vars in their default params.
    instance = cast(VllmDiffusionPipeline, object.__new__(cls))

    # Patch instance variables for the few pipelines that do need them.
    for attr_name, attr_value in INSTANCE_VAR_MOCKS.get(pipeline_type, {}).items():
        setattr(instance, attr_name, attr_value)

    defaults = instance.sampling_param_defaults
    params.merge_with_def_params(defaults)

    # The user's explicit value is prioritized for all models...
    assert params.num_inference_steps == user_steps
    # ...and every other property matches the pipeline defaults, since
    # the user did not pass it.
    for attr_name, val in defaults.validated_overrides.items():
        if attr_name != "num_inference_steps":
            assert getattr(params, attr_name) == val
12 changes: 10 additions & 2 deletions vllm_omni/diffusion/models/bagel/pipeline_bagel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.inputs.data import DiffusionParamOverrides
from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific

from .autoencoder import AutoEncoder, AutoEncoderParams
Expand Down Expand Up @@ -148,12 +150,18 @@ def forward(self, packed_pixel_values, packed_flattened_position_ids, cu_seqlens
return outputs.last_hidden_state.squeeze(0)


class BagelPipeline(nn.Module, DiffusionPipelineProfilerMixin):
class BagelPipeline(VllmDiffusionPipeline, DiffusionPipelineProfilerMixin):
"""Bagel generation pipeline (MoT) packaged for vllm-omni diffusion engine.

This pipeline is self-contained and uses the ported Bagel core files.
"""

@property
def sampling_param_defaults(self):
    """Pipeline-level sampling defaults, merged into user params for any
    fields the user did not explicitly set."""
    return DiffusionParamOverrides(
        num_inference_steps=50,
    )

def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
super().__init__()
self.od_config = od_config
Expand Down Expand Up @@ -334,7 +342,7 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
cfg_renorm_min = extra_args.get("cfg_renorm_min", 0.0)

gen_params = BagelGenParams(
num_timesteps=int(req.sampling_params.num_inference_steps or 50),
num_timesteps=int(req.sampling_params.num_inference_steps),
timestep_shift=3.0,
cfg_text_scale=cfg_text_scale,
cfg_img_scale=cfg_img_scale,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
import torch.distributed
from diffusers import FlowMatchEulerDiscreteScheduler
from PIL import Image, ImageOps
from torch import nn
from torchvision.transforms import Compose, Normalize
from tqdm import tqdm

from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.models.interface import SupportAudioInput, SupportImageInput
from vllm_omni.diffusion.models.interface import SupportAudioInput, SupportImageInput, VllmDiffusionPipeline
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.inputs.data import DiffusionParamOverrides

try:
from dreamid_omni.utils.divisible_crop import DivisibleCrop
Expand Down Expand Up @@ -74,9 +74,15 @@
}


class DreamIDOmniPipeline(nn.Module, CFGParallelMixin, SupportImageInput, SupportAudioInput):
class DreamIDOmniPipeline(VllmDiffusionPipeline, CFGParallelMixin, SupportImageInput, SupportAudioInput):
"""DreamID-Omni pipeline for vLLM-Omni."""

@property
def sampling_param_defaults(self):
    """Pipeline-level sampling defaults, merged into user params for any
    fields the user did not explicitly set."""
    return DiffusionParamOverrides(
        num_inference_steps=50,
    )

def __init__(
self,
*,
Expand Down
24 changes: 15 additions & 9 deletions vllm_omni/diffusion/models/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
FlowMatchEulerDiscreteScheduler,
)
from diffusers.utils.torch_utils import randn_tensor
from torch import nn
from transformers import AutoConfig, CLIPTextModel, CLIPTokenizer, T5TokenizerFast
from vllm.model_executor.models.utils import AutoWeightsLoader

Expand All @@ -27,9 +26,11 @@
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.flux import FluxTransformer2DModel
from vllm_omni.diffusion.models.flux.flux_pipeline_mixin import FluxPipelineMixin
from vllm_omni.diffusion.models.interface import VllmDiffusionPipeline
from vllm_omni.diffusion.models.t5_encoder import T5EncoderModel
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.inputs.data import DiffusionParamOverrides
from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -64,7 +65,15 @@ def post_process_func(images: torch.Tensor):
return post_process_func


class FluxPipeline(nn.Module, FluxPipelineMixin, CFGParallelMixin, DiffusionPipelineProfilerMixin):
class FluxPipeline(VllmDiffusionPipeline, FluxPipelineMixin, CFGParallelMixin, DiffusionPipelineProfilerMixin):
@property
def sampling_param_defaults(self):
    """Pipeline-level sampling defaults, merged into user params for any
    fields the user did not explicitly set."""
    return DiffusionParamOverrides(
        num_inference_steps=28,
        true_cfg_scale=1.0,
        max_sequence_length=512,
    )

def __init__(
self,
*,
Expand Down Expand Up @@ -494,14 +503,11 @@ def forward(
prompt_2: str | list[str] | None = None,
negative_prompt: str | list[str] | None = None,
negative_prompt_2: str | list[str] | None = None,
true_cfg_scale: float = 1.0,
height: int | None = None,
width: int | None = None,
num_inference_steps: int = 28,
sigmas: list[float] | None = None,
guidance_scale: float = 3.5,
num_images_per_prompt: int = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.FloatTensor | None = None,
prompt_embeds: torch.FloatTensor | None = None,
pooled_prompt_embeds: torch.FloatTensor | None = None,
Expand All @@ -511,7 +517,6 @@ def forward(
return_dict: bool = True,
joint_attention_kwargs: dict[str, Any] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
max_sequence_length: int = 512,
):
"""Forward pass for flux."""
# TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "")
Expand All @@ -524,13 +529,14 @@ def forward(

height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor
width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor
num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps
num_inference_steps = req.sampling_params.num_inference_steps
sigmas = req.sampling_params.sigmas or sigmas
guidance_scale = (
req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale
)
generator = req.sampling_params.generator or generator
true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale
max_sequence_length = req.sampling_params.max_sequence_length
generator = req.sampling_params.generator
true_cfg_scale = req.sampling_params.true_cfg_scale
num_images_per_prompt = (
req.sampling_params.num_outputs_per_prompt
if req.sampling_params.num_outputs_per_prompt > 0
Expand Down
Loading