Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 43 additions & 10 deletions tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import pytest
import torch

from vllm_omni.diffusion.models.flux.pipeline_flux import FluxDMD2Pipeline, FluxPipeline
from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import LTX2Pipeline, LTX2T2VDMD2Pipeline
from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import LTX2I2VDMD2Pipeline, LTX2ImageToVideoPipeline
from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import QwenImageDMD2Pipeline, QwenImagePipeline
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline, WanT2VDMD2Pipeline
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import Wan22I2VPipeline, WanI2VDMD2Pipeline
from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams
Expand All @@ -19,6 +21,8 @@
WanI2VDMD2Pipeline: Wan22I2VPipeline,
LTX2T2VDMD2Pipeline: LTX2Pipeline,
LTX2I2VDMD2Pipeline: LTX2ImageToVideoPipeline,
FluxDMD2Pipeline: FluxPipeline,
QwenImageDMD2Pipeline: QwenImagePipeline,
}


Expand Down Expand Up @@ -49,7 +53,7 @@ def _make_request(prompts=None, **sp_kwargs) -> OmniDiffusionRequest:

@pytest.fixture(
params=list(_DMD2_BASE.keys()),
ids=["wan_t2v", "wan_i2v", "ltx2_t2v", "ltx2_i2v"],
ids=["wan_t2v", "wan_i2v", "ltx2_t2v", "ltx2_i2v", "flux", "qwen_image"],
)
def pipeline(request):
return _make_pipeline(request.param)
Expand All @@ -63,13 +67,13 @@ def pipeline(request):
def test_num_inference_steps_forced_to_dmd2_value(pipeline):
req = _make_request(num_inference_steps=40)
pipeline._sanitize_dmd2_request(req)
assert req.sampling_params.num_inference_steps == pipeline.num_inference_steps
assert req.sampling_params.num_inference_steps == pipeline.dmd2_config.num_inference_steps


def test_num_inference_steps_already_correct(pipeline):
req = _make_request(num_inference_steps=pipeline.num_inference_steps)
req = _make_request(num_inference_steps=pipeline.dmd2_config.num_inference_steps)
pipeline._sanitize_dmd2_request(req)
assert req.sampling_params.num_inference_steps == pipeline.num_inference_steps
assert req.sampling_params.num_inference_steps == pipeline.dmd2_config.num_inference_steps


# ---------------------------------------------------------------------------
Expand All @@ -80,19 +84,19 @@ def test_num_inference_steps_already_correct(pipeline):
def test_guidance_scale_forced_to_one(pipeline):
req = _make_request(guidance_scale=5.0, guidance_scale_provided=True)
pipeline._sanitize_dmd2_request(req)
assert req.sampling_params.guidance_scale == pipeline.dmd2_guidance_scale
assert req.sampling_params.guidance_scale == pipeline.dmd2_config.guidance_scale
assert req.sampling_params.guidance_scale_provided is False


def test_guidance_scale_already_correct(pipeline):
req = _make_request(guidance_scale=pipeline.dmd2_guidance_scale, guidance_scale_provided=False)
req = _make_request(guidance_scale=pipeline.dmd2_config.guidance_scale, guidance_scale_provided=False)
pipeline._sanitize_dmd2_request(req)
assert req.sampling_params.guidance_scale == pipeline.dmd2_guidance_scale
assert req.sampling_params.guidance_scale == pipeline.dmd2_config.guidance_scale


def test_guidance_scale_provided_flag_cleared(pipeline):
"""guidance_scale_provided=True must be cleared even if scale is already dmd2_guidance_scale."""
req = _make_request(guidance_scale=pipeline.dmd2_guidance_scale, guidance_scale_provided=True)
req = _make_request(guidance_scale=pipeline.dmd2_config.guidance_scale, guidance_scale_provided=True)
pipeline._sanitize_dmd2_request(req)
assert req.sampling_params.guidance_scale_provided is False

Expand Down Expand Up @@ -164,15 +168,44 @@ def test_multiple_prompts_all_sanitized(pipeline):
# ---------------------------------------------------------------------------


def test_sample_solver_stripped_from_extra_args(pipeline):
"""[C1] defense: sample_solver must not leak into req for the base pipeline to read."""
req = _make_request()
req.sampling_params.extra_args = {"sample_solver": "euler"}
pipeline._sanitize_dmd2_request(req)
assert "sample_solver" not in req.sampling_params.extra_args


def test_flow_shift_stripped_from_extra_args(pipeline):
"""[C1] defense: flow_shift must not leak into req for the base pipeline to read."""
req = _make_request()
req.sampling_params.extra_args = {"flow_shift": 3.0}
pipeline._sanitize_dmd2_request(req)
assert "flow_shift" not in req.sampling_params.extra_args


def test_unrelated_extra_args_preserved(pipeline):
"""Sanitizer only strips sample_solver / flow_shift; other extras pass through."""
req = _make_request()
req.sampling_params.extra_args = {"sample_solver": "euler", "unrelated": 42}
pipeline._sanitize_dmd2_request(req)
assert req.sampling_params.extra_args == {"unrelated": 42}


# ---------------------------------------------------------------------------
# Clean request — nothing changes
# ---------------------------------------------------------------------------


def test_clean_request_no_changes(pipeline):
req = _make_request(
guidance_scale=pipeline.dmd2_guidance_scale,
guidance_scale=pipeline.dmd2_config.guidance_scale,
guidance_scale_provided=False,
do_classifier_free_guidance=False,
is_cfg_negative=False,
)
pipeline._sanitize_dmd2_request(req)
assert req.sampling_params.guidance_scale == pipeline.dmd2_guidance_scale
assert req.sampling_params.guidance_scale == pipeline.dmd2_config.guidance_scale
assert req.sampling_params.guidance_scale_provided is False
assert req.sampling_params.guidance_scale_2 is None
assert req.sampling_params.true_cfg_scale is None
Expand Down
51 changes: 49 additions & 2 deletions tests/diffusion/models/dmd2/test_dmd2_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,27 @@
import pytest
import torch

from vllm_omni.diffusion.models.flux.pipeline_flux import FluxDMD2Pipeline, FluxPipeline
from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import LTX2Pipeline, LTX2T2VDMD2Pipeline
from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import LTX2I2VDMD2Pipeline, LTX2ImageToVideoPipeline
from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import QwenImageDMD2Pipeline, QwenImagePipeline
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline, WanT2VDMD2Pipeline
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import Wan22I2VPipeline, WanI2VDMD2Pipeline
from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]

_DMD2_TIMESTEPS = [999, 937, 833, 624]
# Linspace fallback timesteps for num_inference_steps=4 (the mixin default when model_index is empty).
_DMD2_TIMESTEPS = [999, 749, 499, 249]

# DMD2 subclass → immediate base pipeline whose __init__ loads model weights (mocked in tests).
_DMD2_BASE = {
WanT2VDMD2Pipeline: Wan22Pipeline,
WanI2VDMD2Pipeline: Wan22I2VPipeline,
LTX2T2VDMD2Pipeline: LTX2Pipeline,
LTX2I2VDMD2Pipeline: LTX2ImageToVideoPipeline,
FluxDMD2Pipeline: FluxPipeline,
QwenImageDMD2Pipeline: QwenImagePipeline,
}


Expand Down Expand Up @@ -48,7 +53,7 @@ def _make_request(**sp_kwargs) -> OmniDiffusionRequest:

@pytest.fixture(
params=list(_DMD2_BASE.keys()),
ids=["wan_t2v", "wan_i2v", "ltx2_t2v", "ltx2_i2v"],
ids=["wan_t2v", "wan_i2v", "ltx2_t2v", "ltx2_i2v", "flux", "qwen_image"],
)
def pipeline(request):
return _make_pipeline(request.param)
Expand Down Expand Up @@ -82,6 +87,48 @@ def test_forward_timesteps_match_dmd2_schedule(pipeline):
assert pipeline.scheduler.timesteps.long().tolist() == _DMD2_TIMESTEPS


def test_default_solver_is_ode(pipeline):
"""Default dmd2_config.solver is 'ode' → scheduler.stochastic_sampling is False."""
assert pipeline.dmd2_config.solver == "ode"
assert pipeline.scheduler.config.stochastic_sampling is False


def test_sde_solver_plumbed_to_scheduler():
"""solver='sde' in model_index → scheduler.stochastic_sampling is True."""
from vllm_omni.diffusion.models.dmd2 import DMD2Config
from vllm_omni.diffusion.models.schedulers import DMD2EulerScheduler

cfg = DMD2Config.from_model_index({"dmd2_config": {"solver": "sde"}})
scheduler = DMD2EulerScheduler(
num_train_timesteps=1000,
shift=1.0,
dmd2_timesteps=cfg.resolve_timesteps(),
stochastic_sampling=(cfg.solver == "sde"),
)
assert scheduler.config.stochastic_sampling is True


def test_solver_case_insensitive():
"""'SDE', 'Sde', ' sde ' all normalize to 'sde'."""
from vllm_omni.diffusion.models.dmd2 import DMD2Config

for raw in ("SDE", "Sde", " sde ", "sde"):
cfg = DMD2Config.from_model_index({"dmd2_config": {"solver": raw}})
assert cfg.solver == "sde"


def test_solver_invalid_raises():
"""Unknown solver strings raise ValueError with a clear message."""
import pytest

from vllm_omni.diffusion.models.dmd2 import DMD2Config

with pytest.raises(ValueError, match="solver must be one of"):
DMD2Config.from_model_index({"dmd2_config": {"solver": "euler"}})
with pytest.raises(ValueError, match="solver must be one of"):
DMD2Config(solver="dpmpp") # type: ignore[arg-type]


def test_forward_timesteps_idempotent_across_calls(pipeline):
"""Successive forward() calls must not cause scheduler state to drift."""
parent = _DMD2_BASE[type(pipeline)]
Expand Down
2 changes: 2 additions & 0 deletions vllm_omni/diffusion/models/dmd2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm_omni.diffusion.models.dmd2.config import DMD2Config
from vllm_omni.diffusion.models.dmd2.mixin import DMD2PipelineMixin

__all__ = [
"DMD2Config",
"DMD2PipelineMixin",
]
46 changes: 46 additions & 0 deletions vllm_omni/diffusion/models/dmd2/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

from dataclasses import dataclass
from typing import Literal, get_args

import torch

Solver = Literal["ode", "sde"]


@dataclass
class DMD2Config:
"""Inference-time contract for a FastGen DMD2-distilled checkpoint."""

num_inference_steps: int = 4
denoising_timesteps: list[int] | None = None
solver: Solver = "ode"
guidance_scale: float = 1.0

def __post_init__(self) -> None:
if self.solver not in get_args(Solver):
raise ValueError(f"DMD2Config.solver must be one of {list(get_args(Solver))}, got {self.solver!r}")

@classmethod
def from_model_index(cls, model_index: dict) -> DMD2Config:
"""Read the `dmd2_config` block from a model_index.json dict. Missing block → defaults."""
block = model_index.get("dmd2_config", {})
solver = block.get("solver", cls.solver)
if isinstance(solver, str):
solver = solver.strip().lower()
return cls(
num_inference_steps=block.get("num_inference_steps", cls.num_inference_steps),
denoising_timesteps=block.get("denoising_timesteps"),
solver=solver,
guidance_scale=block.get("guidance_scale", cls.guidance_scale),
)
Comment thread
ayushag-nv marked this conversation as resolved.

def resolve_timesteps(self) -> list[int]:
if self.denoising_timesteps is not None:
return list(self.denoising_timesteps)
# Uniformly spaced timesteps from 999 down toward 0, excluding the final 0.
ts = torch.linspace(999, 0, self.num_inference_steps + 1)[:-1]
return ts.to(torch.int32).tolist()
35 changes: 21 additions & 14 deletions vllm_omni/diffusion/models/dmd2/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os

from vllm_omni.diffusion.data import DiffusionOutput
from vllm_omni.diffusion.models.dmd2.config import DMD2Config
from vllm_omni.diffusion.models.schedulers import DMD2EulerScheduler
from vllm_omni.diffusion.models.utils import _load_json
from vllm_omni.diffusion.request import OmniDiffusionRequest
Expand All @@ -25,36 +26,34 @@ def __init_dmd2__(self) -> None:
except Exception:
model_index = {}

dmd2_timesteps = model_index.get("dmd2_denoising_timesteps", [999, 937, 833, 624])
self.num_inference_steps = model_index.get("dmd2_num_inference_steps", 4)
shift = model_index.get("dmd2_scheduler_shift", 1.0)
self.dmd2_guidance_scale = model_index.get("dmd2_guidance_scale", 1.0)
self.dmd2_config = DMD2Config.from_model_index(model_index)

self.scheduler = DMD2EulerScheduler(
num_train_timesteps=1000,
shift=shift,
dmd2_timesteps=dmd2_timesteps,
shift=1.0,
dmd2_timesteps=self.dmd2_config.resolve_timesteps(),
stochastic_sampling=(self.dmd2_config.solver == "sde"),
)
Comment thread
ayushag-nv marked this conversation as resolved.
Comment thread
ayushag-nv marked this conversation as resolved.

def _sanitize_dmd2_request(self, req: OmniDiffusionRequest) -> None:
"""Sanitize CFG-related fields in-place. Mutates req.sampling_params and req.prompts."""
sp = req.sampling_params

if sp.num_inference_steps and sp.num_inference_steps != self.num_inference_steps:
if sp.num_inference_steps and sp.num_inference_steps != self.dmd2_config.num_inference_steps:
logger.warning(
"DMD2: ignoring num_inference_steps=%d, forcing %d.",
sp.num_inference_steps,
self.num_inference_steps,
self.dmd2_config.num_inference_steps,
)
sp.num_inference_steps = self.num_inference_steps
sp.num_inference_steps = self.dmd2_config.num_inference_steps

if sp.guidance_scale_provided and sp.guidance_scale != self.dmd2_guidance_scale:
if sp.guidance_scale_provided and sp.guidance_scale != self.dmd2_config.guidance_scale:
logger.warning(
"DMD2: ignoring guidance_scale=%.2f, forcing %.2f.",
sp.guidance_scale,
self.dmd2_guidance_scale,
self.dmd2_config.guidance_scale,
)
sp.guidance_scale = self.dmd2_guidance_scale
sp.guidance_scale = self.dmd2_config.guidance_scale
sp.guidance_scale_provided = False

if sp.guidance_scale_2 is not None:
Expand All @@ -68,6 +67,14 @@ def _sanitize_dmd2_request(self, req: OmniDiffusionRequest) -> None:
sp.do_classifier_free_guidance = False
sp.is_cfg_negative = False

# defense: strip scheduler-override extra_args that would let the base pipeline
# (e.g. Wan22Pipeline.forward) rebuild self.scheduler mid-forward and clobber DMD2EulerScheduler.
extra_args = getattr(sp, "extra_args", None) or {}
for key in ("sample_solver", "flow_shift"):
if key in extra_args:
logger.warning("DMD2: ignoring extra_args.%s.", key)
extra_args.pop(key)

fixed = []
for p in req.prompts:
if isinstance(p, dict) and "negative_prompt" in p:
Expand All @@ -82,7 +89,7 @@ def forward(self, req: OmniDiffusionRequest, **kwargs) -> DiffusionOutput:
kwargs.pop("num_inference_steps", None)
return super().forward(
req,
guidance_scale=self.dmd2_guidance_scale,
num_inference_steps=self.num_inference_steps,
guidance_scale=self.dmd2_config.guidance_scale,
num_inference_steps=self.dmd2_config.num_inference_steps,
**kwargs,
)
2 changes: 2 additions & 0 deletions vllm_omni/diffusion/models/flux/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
FluxTransformer2DModel,
)
from vllm_omni.diffusion.models.flux.pipeline_flux import (
FluxDMD2Pipeline,
FluxPipeline,
get_flux_post_process_func,
)
Expand All @@ -17,6 +18,7 @@

__all__ = [
"FluxPipeline",
"FluxDMD2Pipeline",
"FluxKontextPipeline",
"FluxTransformer2DModel",
"FluxKontextTransformer2DModel",
Expand Down
9 changes: 9 additions & 0 deletions vllm_omni/diffusion/models/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from vllm_omni.diffusion.distributed.parallel_state import get_classifier_free_guidance_world_size
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin
from vllm_omni.diffusion.models.flux import FluxTransformer2DModel
from vllm_omni.diffusion.models.flux.flux_pipeline_mixin import FluxPipelineMixin
from vllm_omni.diffusion.models.t5_encoder import T5EncoderModel
Expand Down Expand Up @@ -665,3 +666,11 @@ def forward(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)


class FluxDMD2Pipeline(DMD2PipelineMixin, FluxPipeline):
"""Flux pipeline for FastGen DMD2-distilled models."""

def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
super().__init__(od_config=od_config, prefix=prefix)
self.__init_dmd2__()
2 changes: 2 additions & 0 deletions vllm_omni/diffusion/models/qwen_image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
QwenImageCFGParallelMixin,
)
from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import (
QwenImageDMD2Pipeline,
QwenImagePipeline,
get_qwen_image_post_process_func,
)
Expand All @@ -16,6 +17,7 @@
__all__ = [
"QwenImageCFGParallelMixin",
"QwenImagePipeline",
"QwenImageDMD2Pipeline",
"QwenImageTransformer2DModel",
"get_qwen_image_post_process_func",
]
Loading
Loading