Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
180 changes: 180 additions & 0 deletions tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch

import pytest
import torch

from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import LTX2Pipeline, LTX2T2VDMD2Pipeline
from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import LTX2I2VDMD2Pipeline, LTX2ImageToVideoPipeline
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline, WanT2VDMD2Pipeline
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import Wan22I2VPipeline, WanI2VDMD2Pipeline
from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams

# Every test in this module carries the ``core_model`` and ``cpu`` markers.
pytestmark = [
    pytest.mark.core_model,
    pytest.mark.cpu,
]

# Maps each DMD2 pipeline subclass to the immediate base pipeline whose
# ``__init__`` loads model weights; that ``__init__`` is mocked in these tests.
# Insertion order matters: the ``pipeline`` fixture pairs these keys with a
# hand-written list of test ids.
_DMD2_BASE = {
    WanT2VDMD2Pipeline: Wan22Pipeline,
    WanI2VDMD2Pipeline: Wan22I2VPipeline,
    LTX2T2VDMD2Pipeline: LTX2Pipeline,
    LTX2I2VDMD2Pipeline: LTX2ImageToVideoPipeline,
}


def _make_pipeline(cls):
    """Instantiate DMD2 pipeline ``cls`` without loading any model weights.

    The ``__init__`` of the base pipeline registered for ``cls`` in
    ``_DMD2_BASE`` is temporarily replaced by a stub that only stores the
    config, then ``cls.__init__`` is run on a bare instance.
    """
    fake_config = MagicMock()
    fake_config.model = "/nonexistent"

    def _stub_base_init(self, *_args, **_kwargs):
        # Stand-in for the weight-loading constructor: keep only the config.
        self.od_config = fake_config

    base_cls = _DMD2_BASE[cls]
    with patch.object(base_cls, "__init__", _stub_base_init):
        instance = object.__new__(cls)
        # nn.Module bookkeeping must exist before attributes are assigned.
        torch.nn.Module.__init__(instance)
        cls.__init__(instance, od_config=fake_config)
    return instance


def _make_request(prompts=None, **sp_kwargs) -> OmniDiffusionRequest:
    """Build a request whose sampling params are created from ``sp_kwargs``.

    A missing or empty ``prompts`` falls back to a single dict prompt.
    """
    sampling_params = OmniDiffusionSamplingParams(**sp_kwargs)
    if not prompts:
        prompts = [{"prompt": "a cat dancing"}]
    return OmniDiffusionRequest(prompts=prompts, sampling_params=sampling_params)


@pytest.fixture(
    params=list(_DMD2_BASE),
    ids=["wan_t2v", "wan_i2v", "ltx2_t2v", "ltx2_i2v"],
)
def pipeline(request):
    """Yield one mocked DMD2 pipeline per class listed in ``_DMD2_BASE``."""
    dmd2_cls = request.param
    return _make_pipeline(dmd2_cls)


# ---------------------------------------------------------------------------
# num_inference_steps
# ---------------------------------------------------------------------------


def test_num_inference_steps_forced_to_dmd2_value(pipeline):
    """A caller-supplied step count is replaced by the pipeline's DMD2 step count."""
    dmd2_request = _make_request(num_inference_steps=40)
    pipeline._sanitize_dmd2_request(dmd2_request)
    params = dmd2_request.sampling_params
    assert params.num_inference_steps == pipeline.num_inference_steps


def test_num_inference_steps_already_correct(pipeline):
    """A step count that already matches the pipeline's value is left as is."""
    expected_steps = pipeline.num_inference_steps
    dmd2_request = _make_request(num_inference_steps=expected_steps)
    pipeline._sanitize_dmd2_request(dmd2_request)
    params = dmd2_request.sampling_params
    assert params.num_inference_steps == pipeline.num_inference_steps


# ---------------------------------------------------------------------------
# guidance_scale
# ---------------------------------------------------------------------------


def test_guidance_scale_forced_to_one(pipeline):
    """A user-provided guidance scale is overwritten and its 'provided' flag reset."""
    dmd2_request = _make_request(guidance_scale=5.0, guidance_scale_provided=True)
    pipeline._sanitize_dmd2_request(dmd2_request)
    params = dmd2_request.sampling_params
    assert params.guidance_scale == pipeline.dmd2_guidance_scale
    assert params.guidance_scale_provided is False


def test_guidance_scale_already_correct(pipeline):
    """A guidance scale equal to the DMD2 value keeps that value after sanitization."""
    dmd2_request = _make_request(
        guidance_scale=pipeline.dmd2_guidance_scale,
        guidance_scale_provided=False,
    )
    pipeline._sanitize_dmd2_request(dmd2_request)
    params = dmd2_request.sampling_params
    assert params.guidance_scale == pipeline.dmd2_guidance_scale


def test_guidance_scale_provided_flag_cleared(pipeline):
    """guidance_scale_provided=True must be cleared even if scale is already dmd2_guidance_scale."""
    dmd2_request = _make_request(
        guidance_scale=pipeline.dmd2_guidance_scale,
        guidance_scale_provided=True,
    )
    pipeline._sanitize_dmd2_request(dmd2_request)
    params = dmd2_request.sampling_params
    assert params.guidance_scale_provided is False


def test_guidance_scale_2_cleared(pipeline):
    """A supplied ``guidance_scale_2`` is reset to ``None``."""
    dmd2_request = _make_request(guidance_scale_2=3.0)
    pipeline._sanitize_dmd2_request(dmd2_request)
    params = dmd2_request.sampling_params
    assert params.guidance_scale_2 is None


def test_guidance_scale_2_unset_unchanged(pipeline):
    """With no ``guidance_scale_2`` passed, the field is ``None`` after sanitization."""
    dmd2_request = _make_request()
    pipeline._sanitize_dmd2_request(dmd2_request)
    params = dmd2_request.sampling_params
    assert params.guidance_scale_2 is None


def test_true_cfg_scale_cleared(pipeline):
    """A supplied ``true_cfg_scale`` is reset to ``None``."""
    dmd2_request = _make_request(true_cfg_scale=2.0)
    pipeline._sanitize_dmd2_request(dmd2_request)
    params = dmd2_request.sampling_params
    assert params.true_cfg_scale is None


def test_do_classifier_free_guidance_forced_false(pipeline):
    """``do_classifier_free_guidance=True`` is turned off by sanitization."""
    dmd2_request = _make_request(do_classifier_free_guidance=True)
    pipeline._sanitize_dmd2_request(dmd2_request)
    params = dmd2_request.sampling_params
    assert params.do_classifier_free_guidance is False


def test_is_cfg_negative_forced_false(pipeline):
    """``is_cfg_negative=True`` is turned off by sanitization."""
    dmd2_request = _make_request(is_cfg_negative=True)
    pipeline._sanitize_dmd2_request(dmd2_request)
    params = dmd2_request.sampling_params
    assert params.is_cfg_negative is False


def test_negative_prompt_stripped_from_prompt_dict(pipeline):
    """The ``negative_prompt`` key is removed while the positive prompt is kept."""
    dmd2_request = _make_request(
        prompts=[{"prompt": "a cat", "negative_prompt": "blurry"}],
    )
    pipeline._sanitize_dmd2_request(dmd2_request)
    first_prompt = dmd2_request.prompts[0]
    assert "negative_prompt" not in first_prompt
    assert first_prompt["prompt"] == "a cat"


def test_no_negative_prompt_unchanged(pipeline):
    """A prompt dict without ``negative_prompt`` comes out with the same content."""
    dmd2_request = _make_request(prompts=[{"prompt": "a cat"}])
    pipeline._sanitize_dmd2_request(dmd2_request)
    first_prompt = dmd2_request.prompts[0]
    assert first_prompt == {"prompt": "a cat"}


def test_string_prompt_not_mutated(pipeline):
    """Plain string prompts (as opposed to dicts) must survive sanitization untouched."""
    original_prompts = ["a cat dancing"]
    dmd2_request = _make_request(prompts=original_prompts)
    pipeline._sanitize_dmd2_request(dmd2_request)
    assert dmd2_request.prompts == ["a cat dancing"]


def test_multiple_prompts_all_sanitized(pipeline):
    """Every prompt in a multi-prompt request loses its ``negative_prompt`` key."""
    batch = [
        {"prompt": "a cat", "negative_prompt": "blurry"},
        {"prompt": "a dog", "negative_prompt": "ugly"},
    ]
    dmd2_request = _make_request(prompts=batch)
    pipeline._sanitize_dmd2_request(dmd2_request)
    assert all("negative_prompt" not in prompt for prompt in dmd2_request.prompts)


# ---------------------------------------------------------------------------
# Clean request — nothing changes
# ---------------------------------------------------------------------------


def test_clean_request_no_changes(pipeline):
    """A request already in DMD2-compliant form keeps all of its CFG-related values."""
    compliant_fields = {
        "guidance_scale": pipeline.dmd2_guidance_scale,
        "guidance_scale_provided": False,
        "do_classifier_free_guidance": False,
        "is_cfg_negative": False,
    }
    dmd2_request = _make_request(**compliant_fields)
    pipeline._sanitize_dmd2_request(dmd2_request)

    params = dmd2_request.sampling_params
    assert params.guidance_scale == pipeline.dmd2_guidance_scale
    assert params.guidance_scale_provided is False
    assert params.guidance_scale_2 is None
    assert params.true_cfg_scale is None
    assert params.do_classifier_free_guidance is False
    assert params.is_cfg_negative is False
93 changes: 93 additions & 0 deletions tests/diffusion/models/dmd2/test_dmd2_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch

import pytest
import torch

from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import LTX2Pipeline, LTX2T2VDMD2Pipeline
from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import LTX2I2VDMD2Pipeline, LTX2ImageToVideoPipeline
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline, WanT2VDMD2Pipeline
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import Wan22I2VPipeline, WanI2VDMD2Pipeline
from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams

# Every test in this module carries the ``core_model`` and ``cpu`` markers.
pytestmark = [
    pytest.mark.core_model,
    pytest.mark.cpu,
]

# Timestep schedule the DMD2 scheduler is expected to report in these tests.
_DMD2_TIMESTEPS = [999, 937, 833, 624]

# Maps each DMD2 pipeline subclass to the immediate base pipeline whose
# ``__init__`` loads model weights; that ``__init__`` is mocked in these tests.
# Insertion order matters: the ``pipeline`` fixture pairs these keys with a
# hand-written list of test ids.
_DMD2_BASE = {
    WanT2VDMD2Pipeline: Wan22Pipeline,
    WanI2VDMD2Pipeline: Wan22I2VPipeline,
    LTX2T2VDMD2Pipeline: LTX2Pipeline,
    LTX2I2VDMD2Pipeline: LTX2ImageToVideoPipeline,
}


def _make_pipeline(cls):
    """Instantiate DMD2 pipeline ``cls`` (running ``__init_dmd2__``) with its base mocked.

    The ``__init__`` of the base pipeline registered for ``cls`` in
    ``_DMD2_BASE`` is temporarily replaced by a stub that only stores the
    config, then ``cls.__init__`` is run on a bare instance.
    """
    fake_config = MagicMock()
    fake_config.model = "/nonexistent"

    def _stub_base_init(self, *_args, **_kwargs):
        # __init_dmd2__ reads self.od_config, so the stub must provide it.
        self.od_config = fake_config

    base_cls = _DMD2_BASE[cls]
    with patch.object(base_cls, "__init__", _stub_base_init):
        instance = object.__new__(cls)
        # nn.Module bookkeeping must exist before attributes are assigned.
        torch.nn.Module.__init__(instance)
        cls.__init__(instance, od_config=fake_config)
    return instance


def _make_request(**sp_kwargs) -> OmniDiffusionRequest:
    """Build a single-prompt request whose sampling params come from ``sp_kwargs``."""
    sampling_params = OmniDiffusionSamplingParams(**sp_kwargs)
    single_prompt = [{"prompt": "a cat"}]
    return OmniDiffusionRequest(prompts=single_prompt, sampling_params=sampling_params)


@pytest.fixture(
    params=list(_DMD2_BASE),
    ids=["wan_t2v", "wan_i2v", "ltx2_t2v", "ltx2_i2v"],
)
def pipeline(request):
    """Yield one mocked DMD2 pipeline per class listed in ``_DMD2_BASE``."""
    dmd2_cls = request.param
    return _make_pipeline(dmd2_cls)


# ---------------------------------------------------------------------------
# forward() timestep injection
# ---------------------------------------------------------------------------


def _fake_parent_forward(self, req, *args, num_inference_steps=40, **kwargs):
    """Replacement for the base pipeline's ``forward`` used by these tests.

    It only configures the scheduler with the requested step count on CPU,
    mirroring the ``set_timesteps`` call the real parent is expected to make,
    and returns a mock in place of a real output.
    """
    scheduler = self.scheduler
    scheduler.set_timesteps(num_inference_steps, device="cpu")
    fake_output = MagicMock()
    return fake_output


def test_forward_timesteps_match_dmd2_schedule(pipeline):
    """After forward() runs, scheduler.timesteps must equal the DMD2 training schedule."""
    parent = _DMD2_BASE[type(pipeline)]

    # Sanity check on the scheduler itself: even when asked for 40 steps, the
    # DMD2 scheduler must still report the fixed DMD2 timesteps (it does NOT
    # produce a different, 40-step schedule).
    pipeline.scheduler.set_timesteps(40, device="cpu")
    scheduler_timesteps = pipeline.scheduler.timesteps.long().tolist()
    assert scheduler_timesteps == _DMD2_TIMESTEPS, (
        "DMD2EulerScheduler should always return DMD2 timesteps regardless of num_steps"
    )

    # Run the DMD2 forward() with the parent forward stubbed out; the stub
    # calls set_timesteps with whatever step count forward() hands it.
    with patch.object(parent, "forward", _fake_parent_forward):
        pipeline.forward(_make_request())

    assert pipeline.scheduler.timesteps.long().tolist() == _DMD2_TIMESTEPS


def test_forward_timesteps_idempotent_across_calls(pipeline):
    """Calling forward() twice in a row must leave the scheduler on the DMD2 schedule."""
    base_cls = _DMD2_BASE[type(pipeline)]

    with patch.object(base_cls, "forward", _fake_parent_forward):
        for _ in range(2):
            pipeline.forward(_make_request())

    final_timesteps = pipeline.scheduler.timesteps.long().tolist()
    assert final_timesteps == _DMD2_TIMESTEPS
Empty file.
8 changes: 8 additions & 0 deletions vllm_omni/diffusion/models/dmd2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm_omni.diffusion.models.dmd2.mixin import DMD2PipelineMixin

# Names re-exported as the public API of the ``dmd2`` package.
__all__ = ["DMD2PipelineMixin"]
88 changes: 88 additions & 0 deletions vllm_omni/diffusion/models/dmd2/mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import logging
import os

from vllm_omni.diffusion.data import DiffusionOutput
from vllm_omni.diffusion.models.schedulers import DMD2EulerScheduler
from vllm_omni.diffusion.models.utils import _load_json
from vllm_omni.diffusion.request import OmniDiffusionRequest

# Module-level logger (named after this module) used for the DMD2 warnings below.
logger = logging.getLogger(__name__)


class DMD2PipelineMixin:
    """Mixin for FastGen DMD2-distilled models. Must appear before the base pipeline in MRO.

    The mixin swaps the pipeline scheduler for a ``DMD2EulerScheduler`` and
    forces every request onto the distilled sampling configuration: a fixed
    step count and guidance scale, classifier-free guidance disabled, and no
    negative prompt.

    Attributes set by ``__init_dmd2__``:
        num_inference_steps: Step count forced onto every request.
        dmd2_guidance_scale: Guidance scale forced onto every request.
        scheduler: ``DMD2EulerScheduler`` built from the DMD2 timesteps.
    """

    def __init_dmd2__(self) -> None:
        """Call after super().__init__() to apply DMD2 scheduler and read model_index.

        Reads the optional keys ``dmd2_denoising_timesteps``,
        ``dmd2_num_inference_steps``, ``dmd2_scheduler_shift`` and
        ``dmd2_guidance_scale`` from ``model_index.json`` under
        ``self.od_config.model``. Any key that is absent, and every key when
        the file cannot be loaded, falls back to its built-in default.
        """
        model_path = self.od_config.model
        # An existing local path is read from disk only.
        local_files_only = os.path.exists(model_path)
        try:
            model_index = _load_json(model_path, "model_index.json", local_files_only)
        except Exception as exc:
            # Loading stays best-effort, but the fallback is made visible so a
            # checkpoint silently running on default DMD2 settings can be diagnosed.
            logger.warning(
                "DMD2: could not load model_index.json from %s (%s); using default DMD2 settings.",
                model_path,
                exc,
            )
            model_index = {}

        dmd2_timesteps = model_index.get("dmd2_denoising_timesteps", [999, 937, 833, 624])
        self.num_inference_steps = model_index.get("dmd2_num_inference_steps", 4)
        shift = model_index.get("dmd2_scheduler_shift", 1.0)
        self.dmd2_guidance_scale = model_index.get("dmd2_guidance_scale", 1.0)

        self.scheduler = DMD2EulerScheduler(
            num_train_timesteps=1000,
            shift=shift,
            dmd2_timesteps=dmd2_timesteps,
        )

    def _sanitize_dmd2_request(self, req: OmniDiffusionRequest) -> None:
        """Sanitize CFG-related fields in-place. Mutates req.sampling_params and req.prompts.

        Forces ``num_inference_steps`` and ``guidance_scale`` to the DMD2
        values, clears ``guidance_scale_provided``, ``guidance_scale_2`` and
        ``true_cfg_scale``, disables ``do_classifier_free_guidance`` and
        ``is_cfg_negative``, and removes ``negative_prompt`` from dict prompts.
        A warning is logged whenever a caller-supplied value is discarded.

        Args:
            req: Request to sanitize. Its prompt list is replaced by a new
                list; prompt dicts that contained ``negative_prompt`` are
                replaced by filtered copies rather than edited in place.
        """
        sp = req.sampling_params

        # Warn only when the caller asked for a different, non-zero step count;
        # the value itself is always overwritten.
        if sp.num_inference_steps and sp.num_inference_steps != self.num_inference_steps:
            logger.warning(
                "DMD2: ignoring num_inference_steps=%d, forcing %d.",
                sp.num_inference_steps,
                self.num_inference_steps,
            )
        sp.num_inference_steps = self.num_inference_steps

        if sp.guidance_scale_provided and sp.guidance_scale != self.dmd2_guidance_scale:
            logger.warning(
                "DMD2: ignoring guidance_scale=%.2f, forcing %.2f.",
                sp.guidance_scale,
                self.dmd2_guidance_scale,
            )
        sp.guidance_scale = self.dmd2_guidance_scale
        sp.guidance_scale_provided = False

        if sp.guidance_scale_2 is not None:
            logger.warning("DMD2: ignoring guidance_scale_2.")
            sp.guidance_scale_2 = None

        if sp.true_cfg_scale is not None:
            logger.warning("DMD2: ignoring true_cfg_scale.")
            sp.true_cfg_scale = None

        sp.do_classifier_free_guidance = False
        sp.is_cfg_negative = False

        sanitized_prompts = []
        dropped_negative_prompt = False
        for prompt in req.prompts:
            if isinstance(prompt, dict) and "negative_prompt" in prompt:
                dropped_negative_prompt = True
                # Build a filtered copy so the caller's dict is not modified.
                prompt = {k: v for k, v in prompt.items() if k != "negative_prompt"}
            sanitized_prompts.append(prompt)
        if dropped_negative_prompt:
            # Logged once per request rather than once per prompt, so a batch
            # does not flood the log with identical lines.
            logger.warning("DMD2: ignoring negative_prompt.")
        req.prompts = sanitized_prompts

    def forward(self, req: OmniDiffusionRequest, **kwargs) -> DiffusionOutput:
        """Sanitize ``req`` and delegate to the next ``forward`` in the MRO.

        Any ``guidance_scale`` or ``num_inference_steps`` passed by the caller
        is discarded and replaced with the DMD2 values; all other keyword
        arguments are forwarded unchanged.
        """
        self._sanitize_dmd2_request(req)
        # Drop caller overrides so the forced values below are not passed twice.
        kwargs.pop("guidance_scale", None)
        kwargs.pop("num_inference_steps", None)
        return super().forward(
            req,
            guidance_scale=self.dmd2_guidance_scale,
            num_inference_steps=self.num_inference_steps,
            **kwargs,
        )
4 changes: 4 additions & 0 deletions vllm_omni/diffusion/models/ltx2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,24 @@
from vllm_omni.diffusion.models.ltx2.ltx2_transformer import LTX2VideoTransformer3DModel
from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import (
LTX2Pipeline,
LTX2T2VDMD2Pipeline,
LTX2TwoStagesPipeline,
create_transformer_from_config,
get_ltx2_post_process_func,
load_transformer_config,
)
from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import (
LTX2I2VDMD2Pipeline,
LTX2ImageToVideoPipeline,
LTX2ImageToVideoTwoStagesPipeline,
)
from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline

__all__ = [
"LTX2Pipeline",
"LTX2T2VDMD2Pipeline",
"LTX2ImageToVideoPipeline",
"LTX2I2VDMD2Pipeline",
"LTX2LatentUpsamplePipeline",
"LTX2TwoStagesPipeline",
"LTX2ImageToVideoTwoStagesPipeline",
Expand Down
Loading
Loading