From 76b857176aa1fd48d11eb3191334fee6e68bf672 Mon Sep 17 00:00:00 2001 From: ayushag Date: Mon, 13 Apr 2026 12:30:53 -0700 Subject: [PATCH 1/7] chore: t2v pipeline for wan2.1 dmd2p Signed-off-by: ayushag --- vllm_omni/diffusion/models/wan2_2/__init__.py | 2 + .../models/wan2_2/pipeline_wan2_2.py | 114 ++++++++++++++++++ vllm_omni/diffusion/registry.py | 7 ++ 3 files changed, 123 insertions(+) diff --git a/vllm_omni/diffusion/models/wan2_2/__init__.py b/vllm_omni/diffusion/models/wan2_2/__init__.py index d418001d952..b31781d2a30 100644 --- a/vllm_omni/diffusion/models/wan2_2/__init__.py +++ b/vllm_omni/diffusion/models/wan2_2/__init__.py @@ -3,6 +3,7 @@ from .pipeline_wan2_2 import ( Wan22Pipeline, + WanT2VDMD2Pipeline, create_transformer_from_config, get_wan22_post_process_func, get_wan22_pre_process_func, @@ -28,6 +29,7 @@ from .wan2_2_vace_transformer import VaceWanTransformerBlock, WanVACETransformer3DModel __all__ = [ + "WanT2VDMD2Pipeline", "Wan22Pipeline", "get_wan22_post_process_func", "get_wan22_pre_process_func", diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index a550e576f01..b8e6485dd9e 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -12,6 +12,7 @@ import PIL.Image import torch +from diffusers import FlowMatchEulerDiscreteScheduler from diffusers.utils.torch_utils import randn_tensor from torch import nn from transformers import AutoTokenizer, UMT5EncoderModel @@ -872,3 +873,116 @@ def check_inputs( if boundary_ratio is None and guidance_scale_2 is not None: raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") + + +# --------------------------------------------------------------------------- +# DMD2-distilled variant +# --------------------------------------------------------------------------- + + +class WanT2VDMD2Pipeline(Wan22Pipeline): + """Wan 2.1 T2V pipeline for FastGen DMD2-distilled 4-step models. + + Three changes from Wan22Pipeline: + - FlowMatchEulerDiscreteScheduler(shift=1.0): σ=t exactly, matching the + RFNoiseSchedule used during student training. + - Training timesteps [999, 937, 833, 624] injected via set_timesteps patch. + Euler's default 4-step schedule gives a different distribution. + - guidance_scale and negative_prompt are sanitized: teacher CFG is baked + into student weights so any user-supplied CFG is silently overridden. + """ + + GUIDANCE_SCALE = 1.0 + NUM_INFERENCE_STEPS = 4 + DMD2_TIMESTEPS = [999, 937, 833, 624] + + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): + super().__init__(od_config=od_config, prefix=prefix) + self.scheduler = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, + shift=1.0, + ) + + def _verify_dmd2_request(self, req: OmniDiffusionRequest) -> None: + """Sanitize CFG-related fields from the request. + + DMD2 student weights have teacher CFG baked in. Any CFG-triggering + field supplied by the caller would double-guide the model and degrade + quality. We override all of them with a single warning per field. + """ + sp = req.sampling_params + + if sp.guidance_scale_provided and sp.guidance_scale != self.GUIDANCE_SCALE: + logger.warning( + "DMD2: ignoring guidance_scale=%.2f — CFG is baked into student weights. Forcing guidance_scale=%.2f.", + sp.guidance_scale, + self.GUIDANCE_SCALE, + ) + sp.guidance_scale = self.GUIDANCE_SCALE + sp.guidance_scale_provided = False + + if sp.guidance_scale_2 is not None: + logger.warning("DMD2: ignoring guidance_scale_2 — not supported.") + sp.guidance_scale_2 = None + + # Classifier free guidance is baked in the student weights, set all related params to None + if sp.true_cfg_scale is not None: + logger.warning("DMD2: ignoring true_cfg_scale — not supported.") + sp.true_cfg_scale = None + + sp.do_classifier_free_guidance = False + sp.is_cfg_negative = False + + # negative_prompt comes from req.prompts, not a forward() param. + # Strip it so the parent's encode_prompt never encodes an uncond embedding. + fixed_prompts = [] + for p in req.prompts: + if isinstance(p, dict) and p.get("negative_prompt"): + logger.warning("DMD2: ignoring negative_prompt — not supported for DMD2 models.") + p = {k: v for k, v in p.items() if k != "negative_prompt"} + fixed_prompts.append(p) + req.prompts = fixed_prompts + + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | None = None, + height: int = 480, + width: int = 832, + num_inference_steps: int = 4, + guidance_scale: float | tuple[float, float] = 1.0, + frame_num: int = 81, + output_type: str | None = "np", + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + attention_kwargs: dict | None = None, + **kwargs, + ) -> DiffusionOutput: + # DMD2: teacher CFG is baked into student weights. + # negative_prompt comes from req.prompts — sanitize it there. + # guidance_scale is forced to 1.0 regardless of what the caller passed. + self._verify_dmd2_request(req) + + _orig = self.scheduler.set_timesteps + + def _dmd2_set_timesteps(num_steps, device=None, **kw): + _orig(timesteps=self.DMD2_TIMESTEPS, device=device) + + self.scheduler.set_timesteps = _dmd2_set_timesteps + try: + return super().forward( + req, + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=1.0, + frame_num=frame_num, + output_type=output_type, + generator=generator, + prompt_embeds=prompt_embeds, + attention_kwargs=attention_kwargs, + **kwargs, + ) + finally: + self.scheduler.set_timesteps = _orig diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 97bc7fa2925..7a1c464e371 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -93,6 +93,11 @@ "pipeline_wan2_2_i2v", "Wan22I2VPipeline", ), + "WanT2VDMD2Pipeline": ( + "wan2_2", + "pipeline_wan2_2", + "WanT2VDMD2Pipeline", + ), "LongCatImagePipeline": ( "longcat_image", "pipeline_longcat_image", @@ -359,6 +364,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "LTX2ImageToVideoTwoStagesPipeline": "get_ltx2_post_process_func", "StableAudioPipeline": "get_stable_audio_post_process_func", "WanImageToVideoPipeline": "get_wan22_i2v_post_process_func", + "WanT2VDMD2Pipeline": "get_wan22_post_process_func", "LongCatImagePipeline": "get_longcat_image_post_process_func", "BagelPipeline": "get_bagel_post_process_func", "LongCatImageEditPipeline": "get_longcat_image_post_process_func", @@ -389,6 +395,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "WanPipeline": "get_wan22_pre_process_func", "WanVACEPipeline": "get_wan22_vace_pre_process_func", "WanImageToVideoPipeline": "get_wan22_i2v_pre_process_func", + "WanT2VDMD2Pipeline": "get_wan22_pre_process_func", "OmniGen2Pipeline": "get_omnigen2_pre_process_func", "HeliosPipeline": "get_helios_pre_process_func", "HeliosPyramidPipeline": "get_helios_pre_process_func", From 391418113f1db0b68460a9b3ac0cdba5247aab61 Mon Sep 17 00:00:00 2001 From: ayushag Date: Mon, 13 Apr 2026 12:36:08 -0700 Subject: [PATCH 2/7] chore: i2v pipeline for wan 2.1 dmd2p Signed-off-by: ayushag --- vllm_omni/diffusion/models/wan2_2/__init__.py | 2 + .../models/wan2_2/pipeline_wan2_2_i2v.py | 117 ++++++++++++++++++ vllm_omni/diffusion/registry.py | 7 ++ 3 files changed, 126 insertions(+) diff --git a/vllm_omni/diffusion/models/wan2_2/__init__.py b/vllm_omni/diffusion/models/wan2_2/__init__.py index b31781d2a30..97808df29d8 100644 --- a/vllm_omni/diffusion/models/wan2_2/__init__.py +++ b/vllm_omni/diffusion/models/wan2_2/__init__.py @@ -12,6 +12,7 @@ ) from .pipeline_wan2_2_i2v import ( Wan22I2VPipeline, + WanI2VDMD2Pipeline, get_wan22_i2v_post_process_func, get_wan22_i2v_pre_process_func, ) @@ -37,6 +38,7 @@ "load_transformer_config", "create_transformer_from_config", "Wan22I2VPipeline", + "WanI2VDMD2Pipeline", "get_wan22_i2v_post_process_func", "get_wan22_i2v_pre_process_func", "Wan22TI2VPipeline", diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index c05ecc9c9a2..482791e7972 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -12,6 +12,7 @@ import numpy as np import PIL.Image import torch +from diffusers import FlowMatchEulerDiscreteScheduler from diffusers.utils.torch_utils import randn_tensor from torch import nn from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel @@ -851,3 +852,119 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights using AutoWeightsLoader for vLLM integration.""" loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + +# --------------------------------------------------------------------------- +# DMD2-distilled variant +# --------------------------------------------------------------------------- + + +class WanI2VDMD2Pipeline(Wan22I2VPipeline): + """Wan 2.1 I2V pipeline for FastGen DMD2-distilled 4-step models. + + Three changes from Wan22I2VPipeline: + - FlowMatchEulerDiscreteScheduler(shift=1.0): σ=t exactly, matching the + RFNoiseSchedule used during student training. shift>1 distorts sigmas and + causes noisy output (shift=5.0 → final step dt=-0.892 vs correct -0.624). + - Training timesteps [999, 937, 833, 624] injected via set_timesteps patch. + Euler's default 4-step schedule gives a different distribution. + - guidance_scale and negative_prompt are sanitized: teacher CFG is baked + into student weights so any user-supplied CFG is silently overridden. + """ + + GUIDANCE_SCALE = 1.0 + NUM_INFERENCE_STEPS = 4 + DMD2_TIMESTEPS = [999, 937, 833, 624] + + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): + super().__init__(od_config=od_config, prefix=prefix) + self.scheduler = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, + shift=1.0, + ) + + def _verify_dmd2_request(self, req: OmniDiffusionRequest) -> None: + """Sanitize CFG-related fields from the request. + + DMD2 student weights have teacher CFG baked in. Any CFG-triggering + field supplied by the caller would double-guide the model and degrade + quality. We override all of them with a single warning per field. + """ + sp = req.sampling_params + + if sp.guidance_scale_provided and sp.guidance_scale != self.GUIDANCE_SCALE: + logger.warning( + "DMD2: ignoring guidance_scale=%.2f — CFG is baked into student weights. Forcing guidance_scale=%.2f.", + sp.guidance_scale, + self.GUIDANCE_SCALE, + ) + sp.guidance_scale = self.GUIDANCE_SCALE + sp.guidance_scale_provided = False + + if sp.guidance_scale_2 is not None: + logger.warning("DMD2: ignoring guidance_scale_2 — not supported.") + sp.guidance_scale_2 = None + + if sp.true_cfg_scale is not None: + logger.warning("DMD2: ignoring true_cfg_scale — not supported.") + sp.true_cfg_scale = None + + sp.do_classifier_free_guidance = False + sp.is_cfg_negative = False + + # negative_prompt comes from req.prompts, not a forward() param. + # Strip it so the parent's encode_prompt never encodes an uncond embedding. + fixed_prompts = [] + for p in req.prompts: + if isinstance(p, dict) and p.get("negative_prompt"): + logger.warning("DMD2: ignoring negative_prompt — not supported for DMD2 models.") + p = {k: v for k, v in p.items() if k != "negative_prompt"} + fixed_prompts.append(p) + req.prompts = fixed_prompts + + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | None = None, + image: PIL.Image.Image | torch.Tensor | None = None, + height: int = 480, + width: int = 832, + num_inference_steps: int = 4, + guidance_scale: float | tuple[float, float] = 1.0, + frame_num: int = 81, + output_type: str | None = "np", + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + image_embeds: torch.Tensor | None = None, + last_image: PIL.Image.Image | torch.Tensor | None = None, + attention_kwargs: dict | None = None, + **kwargs, + ) -> DiffusionOutput: + self._verify_dmd2_request(req) + + _orig = self.scheduler.set_timesteps + + def _dmd2_set_timesteps(num_steps, device=None, **kw): + _orig(timesteps=self.DMD2_TIMESTEPS, device=device) + + self.scheduler.set_timesteps = _dmd2_set_timesteps + try: + return super().forward( + req, + prompt=prompt, + image=image, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=self.GUIDANCE_SCALE, + frame_num=frame_num, + output_type=output_type, + generator=generator, + prompt_embeds=prompt_embeds, + image_embeds=image_embeds, + last_image=last_image, + attention_kwargs=attention_kwargs, + **kwargs, + ) + finally: + self.scheduler.set_timesteps = _orig diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 7a1c464e371..81d2e8270c4 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -98,6 +98,11 @@ "pipeline_wan2_2", "WanT2VDMD2Pipeline", ), + "WanI2VDMD2Pipeline": ( + "wan2_2", + "pipeline_wan2_2_i2v", + "WanI2VDMD2Pipeline", + ), "LongCatImagePipeline": ( "longcat_image", "pipeline_longcat_image", @@ -365,6 +370,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "StableAudioPipeline": "get_stable_audio_post_process_func", "WanImageToVideoPipeline": "get_wan22_i2v_post_process_func", "WanT2VDMD2Pipeline": "get_wan22_post_process_func", + "WanI2VDMD2Pipeline": "get_wan22_i2v_post_process_func", "LongCatImagePipeline": "get_longcat_image_post_process_func", "BagelPipeline": "get_bagel_post_process_func", "LongCatImageEditPipeline": "get_longcat_image_post_process_func", @@ -396,6 +402,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "WanVACEPipeline": "get_wan22_vace_pre_process_func", "WanImageToVideoPipeline": "get_wan22_i2v_pre_process_func", "WanT2VDMD2Pipeline": "get_wan22_pre_process_func", + "WanI2VDMD2Pipeline": "get_wan22_i2v_pre_process_func", "OmniGen2Pipeline": "get_omnigen2_pre_process_func", "HeliosPipeline": "get_helios_pre_process_func", "HeliosPyramidPipeline": "get_helios_pre_process_func", From 73c88df7d9f52c7d12f93def6f9c995689d4aa12 Mon Sep 17 00:00:00 2001 From: ayushag Date: Mon, 13 Apr 2026 13:06:57 -0700 Subject: [PATCH 3/7] chore: added unit tests Signed-off-by: ayushag --- tests/diffusion/models/wan2_2/__init__.py | 0 .../test_wan_dmd2_request_sanitization.py | 146 ++++++++++++++++++ .../models/wan2_2/test_wan_dmd2_scheduler.py | 96 ++++++++++++ 3 files changed, 242 insertions(+) create mode 100644 tests/diffusion/models/wan2_2/__init__.py create mode 100644 tests/diffusion/models/wan2_2/test_wan_dmd2_request_sanitization.py create mode 100644 tests/diffusion/models/wan2_2/test_wan_dmd2_scheduler.py diff --git a/tests/diffusion/models/wan2_2/__init__.py b/tests/diffusion/models/wan2_2/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/diffusion/models/wan2_2/test_wan_dmd2_request_sanitization.py b/tests/diffusion/models/wan2_2/test_wan_dmd2_request_sanitization.py new file mode 100644 index 00000000000..f552bdf3579 --- /dev/null +++ b/tests/diffusion/models/wan2_2/test_wan_dmd2_request_sanitization.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import WanT2VDMD2Pipeline +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import WanI2VDMD2Pipeline +from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_pipeline(cls): + """Instantiate a DMD2 pipeline without loading any model weights.""" + import torch + pipeline = object.__new__(cls) + torch.nn.Module.__init__(pipeline) + return pipeline + + +def _make_request(prompts=None, **sp_kwargs) -> OmniDiffusionRequest: + """Build a minimal OmniDiffusionRequest with given sampling params.""" + sp = OmniDiffusionSamplingParams(**sp_kwargs) + return OmniDiffusionRequest( + prompts=prompts or [{"prompt": "a cat dancing"}], + sampling_params=sp, + ) + + +@pytest.fixture(params=[WanT2VDMD2Pipeline, WanI2VDMD2Pipeline], ids=["t2v", "i2v"]) +def pipeline(request): + return _make_pipeline(request.param) + + +# --------------------------------------------------------------------------- +# guidance_scale +# --------------------------------------------------------------------------- + +def test_guidance_scale_forced_to_one(pipeline): + req = _make_request(guidance_scale=5.0, guidance_scale_provided=True) + pipeline._verify_dmd2_request(req) + assert req.sampling_params.guidance_scale == 1.0 + assert req.sampling_params.guidance_scale_provided is False + + +def test_guidance_scale_already_correct(pipeline): + req = _make_request(guidance_scale=1.0, guidance_scale_provided=False) + pipeline._verify_dmd2_request(req) + assert req.sampling_params.guidance_scale == 1.0 + + +def test_guidance_scale_provided_flag_cleared(pipeline): + """guidance_scale_provided=True must be cleared even if scale is already 1.0.""" + req = _make_request(guidance_scale=1.0, guidance_scale_provided=True) + pipeline._verify_dmd2_request(req) + assert req.sampling_params.guidance_scale_provided is False + +def test_guidance_scale_2_cleared(pipeline): + req = _make_request(guidance_scale_2=3.0) + pipeline._verify_dmd2_request(req) + assert req.sampling_params.guidance_scale_2 is None + + +def test_guidance_scale_2_unset_unchanged(pipeline): + req = _make_request() + pipeline._verify_dmd2_request(req) + assert req.sampling_params.guidance_scale_2 is None + + +# --------------------------------------------------------------------------- +# CFG flags +# --------------------------------------------------------------------------- + +def test_true_cfg_scale_cleared(pipeline): + req = _make_request(true_cfg_scale=2.0) + pipeline._verify_dmd2_request(req) + assert req.sampling_params.true_cfg_scale is None + + +def test_do_classifier_free_guidance_forced_false(pipeline): + req = _make_request(do_classifier_free_guidance=True) + pipeline._verify_dmd2_request(req) + assert req.sampling_params.do_classifier_free_guidance is False + + +def test_is_cfg_negative_forced_false(pipeline): + req = _make_request(is_cfg_negative=True) + pipeline._verify_dmd2_request(req) + assert req.sampling_params.is_cfg_negative is False + + +# --------------------------------------------------------------------------- +# negative_prompt in prompt dict +# --------------------------------------------------------------------------- + +def test_negative_prompt_stripped_from_prompt_dict(pipeline): + req = _make_request(prompts=[{"prompt": "a cat", "negative_prompt": "blurry"}]) + pipeline._verify_dmd2_request(req) + assert "negative_prompt" not in req.prompts[0] + assert req.prompts[0]["prompt"] == "a cat" + + +def test_no_negative_prompt_unchanged(pipeline): + req = _make_request(prompts=[{"prompt": "a cat"}]) + pipeline._verify_dmd2_request(req) + assert req.prompts[0] == {"prompt": "a cat"} + + +def test_string_prompt_not_mutated(pipeline): + """String prompts (not dicts) must pass through unchanged.""" + req = _make_request(prompts=["a cat dancing"]) + pipeline._verify_dmd2_request(req) + assert req.prompts == ["a cat dancing"] + + +def test_multiple_prompts_all_sanitized(pipeline): + req = _make_request(prompts=[ + {"prompt": "a cat", "negative_prompt": "blurry"}, + {"prompt": "a dog", "negative_prompt": "ugly"}, + ]) + pipeline._verify_dmd2_request(req) + for p in req.prompts: + assert "negative_prompt" not in p + + +# --------------------------------------------------------------------------- +# Clean request — nothing changes +# --------------------------------------------------------------------------- + +def test_clean_request_no_changes(pipeline): + req = _make_request( + guidance_scale=1.0, + guidance_scale_provided=False, + do_classifier_free_guidance=False, + is_cfg_negative=False, + ) + pipeline._verify_dmd2_request(req) + assert req.sampling_params.guidance_scale == 1.0 + assert req.sampling_params.guidance_scale_provided is False + assert req.sampling_params.guidance_scale_2 is None + assert req.sampling_params.true_cfg_scale is None + assert req.sampling_params.do_classifier_free_guidance is False + assert req.sampling_params.is_cfg_negative is False diff --git a/tests/diffusion/models/wan2_2/test_wan_dmd2_scheduler.py b/tests/diffusion/models/wan2_2/test_wan_dmd2_scheduler.py new file mode 100644 index 00000000000..011baafc8ba --- /dev/null +++ b/tests/diffusion/models/wan2_2/test_wan_dmd2_scheduler.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import inspect +from unittest.mock import MagicMock, patch + +import pytest +from diffusers import FlowMatchEulerDiscreteScheduler + +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import WanT2VDMD2Pipeline +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import WanI2VDMD2Pipeline +from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + +def _make_pipeline(cls): + """ + Instantiate a DMD2 pipeline letting __init__ run — but mocking the + parent's __init__ so no model weights are loaded. + + This verifies that the DMD2 class itself (not the test helper) sets up + the correct scheduler. + """ + import torch + parent = cls.__bases__[0] # Wan22Pipeline or Wan22I2VPipeline + with patch.object(parent, "__init__", lambda *a, **kw: None): + pipeline = object.__new__(cls) + torch.nn.Module.__init__(pipeline) + cls.__init__(pipeline, od_config=MagicMock()) + return pipeline + + +@pytest.fixture(params=[WanT2VDMD2Pipeline, WanI2VDMD2Pipeline], ids=["t2v", "i2v"]) +def pipeline(request): + return _make_pipeline(request.param) + + + +from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams + + +def _make_request(**sp_kwargs) -> OmniDiffusionRequest: + sp = OmniDiffusionSamplingParams(**sp_kwargs) + return OmniDiffusionRequest( + prompts=[{"prompt": "a cat"}], + sampling_params=sp, + ) + + + +def test_scheduler_is_euler(pipeline): + """DMD2 __init__ must replace the parent's UniPC with Euler.""" + assert isinstance(pipeline.scheduler, FlowMatchEulerDiscreteScheduler) + + + +def _fake_parent_forward(self, req, *args, num_inference_steps=40, **kwargs): + """Minimal parent forward() stub: calls set_timesteps exactly as the real parent does.""" + self.scheduler.set_timesteps(num_inference_steps, device="cpu") + return MagicMock() + +def test_forward_timesteps_match_dmd2_schedule(pipeline): + """ + After forward() runs, scheduler.timesteps must equal DMD2_TIMESTEPS. + """ + parent = type(pipeline).__bases__[0] + + # Baseline: Euler scheduler with num_steps=40 gives a different schedule + pipeline.scheduler.set_timesteps(40, device="cpu") + default_timesteps = pipeline.scheduler.timesteps.long().tolist() + assert default_timesteps != pipeline.DMD2_TIMESTEPS, ( + "Euler scheduler default 40-step schedule unexpectedly matches DMD2_TIMESTEPS — " + "this test would be vacuous." + ) + + # After DMD2 forward() — scheduler.timesteps must be DMD2_TIMESTEPS + # regardless of the num_steps the caller passed (40 here). + with patch.object(parent, "forward", _fake_parent_forward): + pipeline.forward(_make_request(), num_inference_steps=40) + + assert pipeline.scheduler.timesteps.long().tolist() == pipeline.DMD2_TIMESTEPS + + +def test_forward_timesteps_fixed_across_num_steps(pipeline): + """DMD2 timesteps are always the same regardless of what num_steps the caller passes.""" + parent = type(pipeline).__bases__[0] + + for num_steps in [1, 4, 10, 40, 100]: + with patch.object(parent, "forward", _fake_parent_forward): + pipeline.forward(_make_request(), num_inference_steps=num_steps) + + assert pipeline.scheduler.timesteps.long().tolist() == pipeline.DMD2_TIMESTEPS, ( + f"num_steps={num_steps}: scheduler.timesteps {pipeline.scheduler.timesteps.tolist()} " + f"!= DMD2_TIMESTEPS {pipeline.DMD2_TIMESTEPS}" + ) + + From 3176727fe3e78be0913c1b772d3920da3ff38462 Mon Sep 17 00:00:00 2001 From: ayushag Date: Thu, 16 Apr 2026 11:53:18 -0700 Subject: [PATCH 4/7] chore: mixin based architecture + fixes Signed-off-by: ayushag --- .../test_wan_dmd2_request_sanitization.py | 111 +++++++----- .../models/wan2_2/test_wan_dmd2_scheduler.py | 96 +++++----- .../models/wan2_2/pipeline_wan2_2.py | 170 +++++++++--------- .../models/wan2_2/pipeline_wan2_2_i2v.py | 135 +------------- 4 files changed, 207 insertions(+), 305 deletions(-) diff --git a/tests/diffusion/models/wan2_2/test_wan_dmd2_request_sanitization.py b/tests/diffusion/models/wan2_2/test_wan_dmd2_request_sanitization.py index f552bdf3579..474c324518a 100644 --- a/tests/diffusion/models/wan2_2/test_wan_dmd2_request_sanitization.py +++ b/tests/diffusion/models/wan2_2/test_wan_dmd2_request_sanitization.py @@ -1,28 +1,41 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from unittest.mock import MagicMock, patch + import pytest +import torch -from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import WanT2VDMD2Pipeline -from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import WanI2VDMD2Pipeline +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline, WanT2VDMD2Pipeline +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import Wan22I2VPipeline, WanI2VDMD2Pipeline from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams pytestmark = [pytest.mark.core_model, pytest.mark.cpu] +# Wan base pipeline whose __init__ loads model weights — mocked in tests. +_WAN_BASE = { + WanT2VDMD2Pipeline: Wan22Pipeline, + WanI2VDMD2Pipeline: Wan22I2VPipeline, +} -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- def _make_pipeline(cls): - """Instantiate a DMD2 pipeline without loading any model weights.""" - import torch - pipeline = object.__new__(cls) - torch.nn.Module.__init__(pipeline) + """Run the DMD2 __init__ with the Wan base mocked out (no model weights loaded).""" + + base = _WAN_BASE[cls] + od_config = MagicMock() + od_config.model = "/nonexistent" + + def _mock_base_init(self, *a, **kw): + self.od_config = od_config + + with patch.object(base, "__init__", _mock_base_init): + pipeline = object.__new__(cls) + torch.nn.Module.__init__(pipeline) + cls.__init__(pipeline, od_config=od_config) return pipeline def _make_request(prompts=None, **sp_kwargs) -> OmniDiffusionRequest: - """Build a minimal OmniDiffusionRequest with given sampling params.""" sp = OmniDiffusionSamplingParams(**sp_kwargs) return OmniDiffusionRequest( prompts=prompts or [{"prompt": "a cat dancing"}], @@ -35,93 +48,106 @@ def pipeline(request): return _make_pipeline(request.param) +# --------------------------------------------------------------------------- +# num_inference_steps +# --------------------------------------------------------------------------- + + +def test_num_inference_steps_forced_to_dmd2_value(pipeline): + req = _make_request(num_inference_steps=40) + pipeline._sanitize_dmd2_request(req) + assert req.sampling_params.num_inference_steps == pipeline.num_inference_steps + + +def test_num_inference_steps_already_correct(pipeline): + req = _make_request(num_inference_steps=pipeline.num_inference_steps) + pipeline._sanitize_dmd2_request(req) + assert req.sampling_params.num_inference_steps == pipeline.num_inference_steps + + # --------------------------------------------------------------------------- # guidance_scale # --------------------------------------------------------------------------- + def test_guidance_scale_forced_to_one(pipeline): req = _make_request(guidance_scale=5.0, guidance_scale_provided=True) - pipeline._verify_dmd2_request(req) - assert req.sampling_params.guidance_scale == 1.0 + pipeline._sanitize_dmd2_request(req) + assert req.sampling_params.guidance_scale == pipeline.dmd2_guidance_scale assert req.sampling_params.guidance_scale_provided is False def test_guidance_scale_already_correct(pipeline): - req = _make_request(guidance_scale=1.0, guidance_scale_provided=False) - pipeline._verify_dmd2_request(req) - assert req.sampling_params.guidance_scale == 1.0 + req = _make_request(guidance_scale=pipeline.dmd2_guidance_scale, guidance_scale_provided=False) + pipeline._sanitize_dmd2_request(req) + assert req.sampling_params.guidance_scale == pipeline.dmd2_guidance_scale def test_guidance_scale_provided_flag_cleared(pipeline): - """guidance_scale_provided=True must be cleared even if scale is already 1.0.""" - req = _make_request(guidance_scale=1.0, guidance_scale_provided=True) - pipeline._verify_dmd2_request(req) + """guidance_scale_provided=True must be cleared even if scale is already dmd2_guidance_scale.""" + req = _make_request(guidance_scale=pipeline.dmd2_guidance_scale, guidance_scale_provided=True) + pipeline._sanitize_dmd2_request(req) assert req.sampling_params.guidance_scale_provided is False + def test_guidance_scale_2_cleared(pipeline): req = _make_request(guidance_scale_2=3.0) - pipeline._verify_dmd2_request(req) + pipeline._sanitize_dmd2_request(req) assert req.sampling_params.guidance_scale_2 is None def test_guidance_scale_2_unset_unchanged(pipeline): req = _make_request() - pipeline._verify_dmd2_request(req) + pipeline._sanitize_dmd2_request(req) assert req.sampling_params.guidance_scale_2 is None -# --------------------------------------------------------------------------- -# CFG flags -# --------------------------------------------------------------------------- - def test_true_cfg_scale_cleared(pipeline): req = _make_request(true_cfg_scale=2.0) - pipeline._verify_dmd2_request(req) + pipeline._sanitize_dmd2_request(req) assert req.sampling_params.true_cfg_scale is None def test_do_classifier_free_guidance_forced_false(pipeline): req = _make_request(do_classifier_free_guidance=True) - pipeline._verify_dmd2_request(req) + pipeline._sanitize_dmd2_request(req) assert req.sampling_params.do_classifier_free_guidance is False def test_is_cfg_negative_forced_false(pipeline): req = _make_request(is_cfg_negative=True) - pipeline._verify_dmd2_request(req) + pipeline._sanitize_dmd2_request(req) assert req.sampling_params.is_cfg_negative is False -# --------------------------------------------------------------------------- -# negative_prompt in prompt dict -# --------------------------------------------------------------------------- - def test_negative_prompt_stripped_from_prompt_dict(pipeline): req = _make_request(prompts=[{"prompt": "a cat", "negative_prompt": "blurry"}]) - pipeline._verify_dmd2_request(req) + pipeline._sanitize_dmd2_request(req) assert "negative_prompt" not in req.prompts[0] assert req.prompts[0]["prompt"] == "a cat" def test_no_negative_prompt_unchanged(pipeline): req = _make_request(prompts=[{"prompt": "a cat"}]) - pipeline._verify_dmd2_request(req) + pipeline._sanitize_dmd2_request(req) assert req.prompts[0] == {"prompt": "a cat"} def test_string_prompt_not_mutated(pipeline): """String prompts (not dicts) must pass through unchanged.""" req = _make_request(prompts=["a cat dancing"]) - pipeline._verify_dmd2_request(req) + pipeline._sanitize_dmd2_request(req) assert req.prompts == ["a cat dancing"] def test_multiple_prompts_all_sanitized(pipeline): - req = _make_request(prompts=[ - {"prompt": "a cat", "negative_prompt": "blurry"}, - {"prompt": "a dog", "negative_prompt": "ugly"}, - ]) - pipeline._verify_dmd2_request(req) + req = _make_request( + prompts=[ + {"prompt": "a cat", "negative_prompt": "blurry"}, + {"prompt": "a dog", "negative_prompt": "ugly"}, + ] + ) + pipeline._sanitize_dmd2_request(req) for p in req.prompts: assert "negative_prompt" not in p @@ -130,15 +156,16 @@ def test_multiple_prompts_all_sanitized(pipeline): # Clean request — nothing changes # --------------------------------------------------------------------------- + def test_clean_request_no_changes(pipeline): req = _make_request( - guidance_scale=1.0, + guidance_scale=pipeline.dmd2_guidance_scale, guidance_scale_provided=False, do_classifier_free_guidance=False, is_cfg_negative=False, ) - pipeline._verify_dmd2_request(req) - assert req.sampling_params.guidance_scale == 1.0 + pipeline._sanitize_dmd2_request(req) + assert req.sampling_params.guidance_scale == pipeline.dmd2_guidance_scale assert req.sampling_params.guidance_scale_provided is False assert req.sampling_params.guidance_scale_2 is None assert req.sampling_params.true_cfg_scale is None diff --git a/tests/diffusion/models/wan2_2/test_wan_dmd2_scheduler.py b/tests/diffusion/models/wan2_2/test_wan_dmd2_scheduler.py index 011baafc8ba..99b097f2a07 100644 --- a/tests/diffusion/models/wan2_2/test_wan_dmd2_scheduler.py +++ b/tests/diffusion/models/wan2_2/test_wan_dmd2_scheduler.py @@ -1,96 +1,88 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import inspect from unittest.mock import MagicMock, patch import pytest -from diffusers import FlowMatchEulerDiscreteScheduler +import torch -from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import WanT2VDMD2Pipeline -from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import WanI2VDMD2Pipeline -from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline, WanT2VDMD2Pipeline +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import Wan22I2VPipeline, WanI2VDMD2Pipeline +from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams pytestmark = [pytest.mark.core_model, pytest.mark.cpu] -def _make_pipeline(cls): - """ - Instantiate a DMD2 pipeline letting __init__ run — but mocking the - parent's __init__ so no model weights are loaded. - - This verifies that the DMD2 class itself (not the test helper) sets up - the correct scheduler. - """ - import torch - parent = cls.__bases__[0] # Wan22Pipeline or Wan22I2VPipeline - with patch.object(parent, "__init__", lambda *a, **kw: None): - pipeline = object.__new__(cls) - torch.nn.Module.__init__(pipeline) - cls.__init__(pipeline, od_config=MagicMock()) - return pipeline +_DMD2_TIMESTEPS = [999, 937, 833, 624] +# Wan base pipeline whose __init__ loads model weights — mocked in tests. +_WAN_BASE = { + WanT2VDMD2Pipeline: Wan22Pipeline, + WanI2VDMD2Pipeline: Wan22I2VPipeline, +} -@pytest.fixture(params=[WanT2VDMD2Pipeline, WanI2VDMD2Pipeline], ids=["t2v", "i2v"]) -def pipeline(request): - return _make_pipeline(request.param) +def _make_pipeline(cls): + """Run the DMD2 __init__ (including __init_dmd2__) with the Wan base mocked.""" + base = _WAN_BASE[cls] + od_config = MagicMock() + od_config.model = "/nonexistent" # _load_model_index returns {} → uses inline defaults -from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams + def _mock_base_init(self, *a, **kw): + self.od_config = od_config # __init_dmd2__ needs this + + with patch.object(base, "__init__", _mock_base_init): + pipeline = object.__new__(cls) + torch.nn.Module.__init__(pipeline) + cls.__init__(pipeline, od_config=od_config) + return pipeline def _make_request(**sp_kwargs) -> OmniDiffusionRequest: sp = OmniDiffusionSamplingParams(**sp_kwargs) - return OmniDiffusionRequest( - prompts=[{"prompt": "a cat"}], - sampling_params=sp, - ) + return OmniDiffusionRequest(prompts=[{"prompt": "a cat"}], sampling_params=sp) +@pytest.fixture(params=[WanT2VDMD2Pipeline, WanI2VDMD2Pipeline], ids=["t2v", "i2v"]) +def pipeline(request): + return _make_pipeline(request.param) -def test_scheduler_is_euler(pipeline): - """DMD2 __init__ must replace the parent's UniPC with Euler.""" - assert isinstance(pipeline.scheduler, FlowMatchEulerDiscreteScheduler) +# --------------------------------------------------------------------------- +# forward() timestep injection +# --------------------------------------------------------------------------- def _fake_parent_forward(self, req, *args, num_inference_steps=40, **kwargs): - """Minimal parent forward() stub: calls set_timesteps exactly as the real parent does.""" + """Stub that calls set_timesteps as the real parent does.""" self.scheduler.set_timesteps(num_inference_steps, device="cpu") return MagicMock() + def test_forward_timesteps_match_dmd2_schedule(pipeline): - """ - After forward() runs, scheduler.timesteps must equal DMD2_TIMESTEPS. - """ - parent = type(pipeline).__bases__[0] + """After forward() runs, scheduler.timesteps must equal the DMD2 training schedule.""" + parent = _WAN_BASE[type(pipeline)] - # Baseline: Euler scheduler with num_steps=40 gives a different schedule + # Baseline: calling set_timesteps(40) without the DMD2 override gives a different schedule pipeline.scheduler.set_timesteps(40, device="cpu") default_timesteps = pipeline.scheduler.timesteps.long().tolist() - assert default_timesteps != pipeline.DMD2_TIMESTEPS, ( - "Euler scheduler default 40-step schedule unexpectedly matches DMD2_TIMESTEPS — " - "this test would be vacuous." + assert default_timesteps == _DMD2_TIMESTEPS, ( + "DMD2EulerScheduler should always return DMD2 timesteps regardless of num_steps" ) - # After DMD2 forward() — scheduler.timesteps must be DMD2_TIMESTEPS - # regardless of the num_steps the caller passed (40 here). with patch.object(parent, "forward", _fake_parent_forward): - pipeline.forward(_make_request(), num_inference_steps=40) + pipeline.forward(_make_request()) - assert pipeline.scheduler.timesteps.long().tolist() == pipeline.DMD2_TIMESTEPS + assert pipeline.scheduler.timesteps.long().tolist() == _DMD2_TIMESTEPS def test_forward_timesteps_fixed_across_num_steps(pipeline): - """DMD2 timesteps are always the same regardless of what num_steps the caller passes.""" - parent = type(pipeline).__bases__[0] + """scheduler.timesteps is always the DMD2 schedule regardless of num_steps passed.""" + parent = _WAN_BASE[type(pipeline)] for num_steps in [1, 4, 10, 40, 100]: with patch.object(parent, "forward", _fake_parent_forward): - pipeline.forward(_make_request(), num_inference_steps=num_steps) + pipeline.forward(_make_request()) - assert pipeline.scheduler.timesteps.long().tolist() == pipeline.DMD2_TIMESTEPS, ( - f"num_steps={num_steps}: scheduler.timesteps {pipeline.scheduler.timesteps.tolist()} " - f"!= DMD2_TIMESTEPS {pipeline.DMD2_TIMESTEPS}" + assert pipeline.scheduler.timesteps.long().tolist() == _DMD2_TIMESTEPS, ( + f"num_steps={num_steps}: got {pipeline.scheduler.timesteps.tolist()}" ) - - diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index b8e6485dd9e..b8d1a3d5067 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -880,109 +880,117 @@ def check_inputs( # --------------------------------------------------------------------------- -class WanT2VDMD2Pipeline(Wan22Pipeline): - """Wan 2.1 T2V pipeline for FastGen DMD2-distilled 4-step models. - - Three changes from Wan22Pipeline: - - FlowMatchEulerDiscreteScheduler(shift=1.0): σ=t exactly, matching the - RFNoiseSchedule used during student training. - - Training timesteps [999, 937, 833, 624] injected via set_timesteps patch. - Euler's default 4-step schedule gives a different distribution. - - guidance_scale and negative_prompt are sanitized: teacher CFG is baked - into student weights so any user-supplied CFG is silently overridden. - """ - - GUIDANCE_SCALE = 1.0 - NUM_INFERENCE_STEPS = 4 - DMD2_TIMESTEPS = [999, 937, 833, 624] +def _load_model_index(model: str, local_files_only: bool) -> dict: + """Load model_index.json from local path or HF Hub.""" + if local_files_only: + model_index_path = os.path.join(model, "model_index.json") + if os.path.exists(model_index_path): + with open(model_index_path) as f: + return json.load(f) + else: + try: + from huggingface_hub import hf_hub_download - def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): - super().__init__(od_config=od_config, prefix=prefix) - self.scheduler = FlowMatchEulerDiscreteScheduler( + model_index_path = hf_hub_download(model, "model_index.json") + with open(model_index_path) as f: + return json.load(f) + except Exception: + pass + return {} + + +class DMD2EulerScheduler(FlowMatchEulerDiscreteScheduler): + """Euler scheduler that always uses the fixed DMD2 training timestep schedule.""" + + def __init__(self, *args, dmd2_timesteps: list[int], **kwargs): + super().__init__(*args, **kwargs) + self._dmd2_timesteps = dmd2_timesteps + + def set_timesteps( + self, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + **kwargs, + ) -> None: + super().set_timesteps(timesteps=self._dmd2_timesteps, device=device) + + +class DMD2PipelineMixin: + """Mixin for FastGen DMD2-distilled models. Must appear before the base pipeline in MRO.""" + + def __init_dmd2__(self) -> None: + """Call after super().__init__() to apply DMD2 scheduler and read model_index.""" + local_files_only = os.path.exists(self.od_config.model) + model_index = _load_model_index(self.od_config.model, local_files_only) + + dmd2_timesteps = model_index.get("dmd2_denoising_timesteps", [999, 937, 833, 624]) + self.num_inference_steps = model_index.get("dmd2_num_inference_steps", 4) + shift = model_index.get("dmd2_scheduler_shift", 1.0) + self.dmd2_guidance_scale = model_index.get("dmd2_guidance_scale", 1.0) + + self.scheduler = DMD2EulerScheduler( num_train_timesteps=1000, - shift=1.0, + shift=shift, + dmd2_timesteps=dmd2_timesteps, ) - def _verify_dmd2_request(self, req: OmniDiffusionRequest) -> None: - """Sanitize CFG-related fields from the request. - - DMD2 student weights have teacher CFG baked in. Any CFG-triggering - field supplied by the caller would double-guide the model and degrade - quality. We override all of them with a single warning per field. - """ + def _sanitize_dmd2_request(self, req: OmniDiffusionRequest) -> None: + """Sanitize CFG-related fields in-place. Mutates req.sampling_params and req.prompts.""" sp = req.sampling_params - if sp.guidance_scale_provided and sp.guidance_scale != self.GUIDANCE_SCALE: + if sp.num_inference_steps and sp.num_inference_steps != self.num_inference_steps: + logger.warning( + "DMD2: ignoring num_inference_steps=%d, forcing %d.", + sp.num_inference_steps, + self.num_inference_steps, + ) + sp.num_inference_steps = self.num_inference_steps + + if sp.guidance_scale_provided and sp.guidance_scale != self.dmd2_guidance_scale: logger.warning( - "DMD2: ignoring guidance_scale=%.2f — CFG is baked into student weights. Forcing guidance_scale=%.2f.", + "DMD2: ignoring guidance_scale=%.2f, forcing %.2f.", sp.guidance_scale, - self.GUIDANCE_SCALE, + self.dmd2_guidance_scale, ) - sp.guidance_scale = self.GUIDANCE_SCALE + sp.guidance_scale = self.dmd2_guidance_scale sp.guidance_scale_provided = False if sp.guidance_scale_2 is not None: - logger.warning("DMD2: ignoring guidance_scale_2 — not supported.") + logger.warning("DMD2: ignoring guidance_scale_2.") sp.guidance_scale_2 = None - # Classifier free guidance is baked in the student weights, set all related params to None if sp.true_cfg_scale is not None: - logger.warning("DMD2: ignoring true_cfg_scale — not supported.") + logger.warning("DMD2: ignoring true_cfg_scale.") sp.true_cfg_scale = None sp.do_classifier_free_guidance = False sp.is_cfg_negative = False - # negative_prompt comes from req.prompts, not a forward() param. - # Strip it so the parent's encode_prompt never encodes an uncond embedding. - fixed_prompts = [] + fixed = [] for p in req.prompts: - if isinstance(p, dict) and p.get("negative_prompt"): - logger.warning("DMD2: ignoring negative_prompt — not supported for DMD2 models.") + if isinstance(p, dict) and "negative_prompt" in p: + logger.warning("DMD2: ignoring negative_prompt.") p = {k: v for k, v in p.items() if k != "negative_prompt"} - fixed_prompts.append(p) - req.prompts = fixed_prompts - - def forward( - self, - req: OmniDiffusionRequest, - prompt: str | None = None, - height: int = 480, - width: int = 832, - num_inference_steps: int = 4, - guidance_scale: float | tuple[float, float] = 1.0, - frame_num: int = 81, - output_type: str | None = "np", - generator: torch.Generator | list[torch.Generator] | None = None, - prompt_embeds: torch.Tensor | None = None, - attention_kwargs: dict | None = None, - **kwargs, - ) -> DiffusionOutput: - # DMD2: teacher CFG is baked into student weights. - # negative_prompt comes from req.prompts — sanitize it there. - # guidance_scale is forced to 1.0 regardless of what the caller passed. - self._verify_dmd2_request(req) + fixed.append(p) + req.prompts = fixed + + def forward(self, req: OmniDiffusionRequest, **kwargs) -> DiffusionOutput: + self._sanitize_dmd2_request(req) + # Safety: remove DMD2-controlled params from kwargs to avoid TypeError + # if a caller passes them explicitly alongside **kwargs. + kwargs.pop("guidance_scale", None) + kwargs.pop("num_inference_steps", None) + return super().forward( + req, + guidance_scale=self.dmd2_guidance_scale, + num_inference_steps=self.num_inference_steps, + **kwargs, + ) - _orig = self.scheduler.set_timesteps - def _dmd2_set_timesteps(num_steps, device=None, **kw): - _orig(timesteps=self.DMD2_TIMESTEPS, device=device) +class WanT2VDMD2Pipeline(DMD2PipelineMixin, Wan22Pipeline): + """Wan 2.1 T2V pipeline for FastGen DMD2-distilled 4-step models.""" - self.scheduler.set_timesteps = _dmd2_set_timesteps - try: - return super().forward( - req, - prompt=prompt, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=1.0, - frame_num=frame_num, - output_type=output_type, - generator=generator, - prompt_embeds=prompt_embeds, - attention_kwargs=attention_kwargs, - **kwargs, - ) - finally: - self.scheduler.set_timesteps = _orig + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): + super().__init__(od_config=od_config, prefix=prefix) + self.__init_dmd2__() diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 482791e7972..e8c61a6310b 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -12,7 +12,6 @@ import numpy as np import PIL.Image import torch -from diffusers import FlowMatchEulerDiscreteScheduler from diffusers.utils.torch_utils import randn_tensor from torch import nn from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel @@ -27,6 +26,8 @@ from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( + DMD2PipelineMixin, + _load_model_index, create_transformer_from_config, load_transformer_config, retrieve_latents, @@ -40,29 +41,6 @@ DEBUG_PERF = False -def _load_model_index(model: str, local_files_only: bool) -> dict: - """Load model_index.json from local path or HF Hub.""" - if local_files_only: - model_index_path = os.path.join(model, "model_index.json") - if os.path.exists(model_index_path): - import json - - with open(model_index_path) as f: - return json.load(f) - else: - try: - import json - - from huggingface_hub import hf_hub_download - - model_index_path = hf_hub_download(model, "model_index.json") - with open(model_index_path) as f: - return json.load(f) - except Exception: - pass - return {} - - def get_wan22_i2v_post_process_func( od_config: OmniDiffusionConfig, ): @@ -859,112 +837,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # --------------------------------------------------------------------------- -class WanI2VDMD2Pipeline(Wan22I2VPipeline): - """Wan 2.1 I2V pipeline for FastGen DMD2-distilled 4-step models. - - Three changes from Wan22I2VPipeline: - - FlowMatchEulerDiscreteScheduler(shift=1.0): σ=t exactly, matching the - RFNoiseSchedule used during student training. shift>1 distorts sigmas and - causes noisy output (shift=5.0 → final step dt=-0.892 vs correct -0.624). - - Training timesteps [999, 937, 833, 624] injected via set_timesteps patch. - Euler's default 4-step schedule gives a different distribution. - - guidance_scale and negative_prompt are sanitized: teacher CFG is baked - into student weights so any user-supplied CFG is silently overridden. - """ - - GUIDANCE_SCALE = 1.0 - NUM_INFERENCE_STEPS = 4 - DMD2_TIMESTEPS = [999, 937, 833, 624] +class WanI2VDMD2Pipeline(DMD2PipelineMixin, Wan22I2VPipeline): + """Wan 2.1 I2V pipeline for FastGen DMD2-distilled 4-step models.""" def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): super().__init__(od_config=od_config, prefix=prefix) - self.scheduler = FlowMatchEulerDiscreteScheduler( - num_train_timesteps=1000, - shift=1.0, - ) - - def _verify_dmd2_request(self, req: OmniDiffusionRequest) -> None: - """Sanitize CFG-related fields from the request. - - DMD2 student weights have teacher CFG baked in. Any CFG-triggering - field supplied by the caller would double-guide the model and degrade - quality. We override all of them with a single warning per field. - """ - sp = req.sampling_params - - if sp.guidance_scale_provided and sp.guidance_scale != self.GUIDANCE_SCALE: - logger.warning( - "DMD2: ignoring guidance_scale=%.2f — CFG is baked into student weights. Forcing guidance_scale=%.2f.", - sp.guidance_scale, - self.GUIDANCE_SCALE, - ) - sp.guidance_scale = self.GUIDANCE_SCALE - sp.guidance_scale_provided = False - - if sp.guidance_scale_2 is not None: - logger.warning("DMD2: ignoring guidance_scale_2 — not supported.") - sp.guidance_scale_2 = None - - if sp.true_cfg_scale is not None: - logger.warning("DMD2: ignoring true_cfg_scale — not supported.") - sp.true_cfg_scale = None - - sp.do_classifier_free_guidance = False - sp.is_cfg_negative = False - - # negative_prompt comes from req.prompts, not a forward() param. - # Strip it so the parent's encode_prompt never encodes an uncond embedding. - fixed_prompts = [] - for p in req.prompts: - if isinstance(p, dict) and p.get("negative_prompt"): - logger.warning("DMD2: ignoring negative_prompt — not supported for DMD2 models.") - p = {k: v for k, v in p.items() if k != "negative_prompt"} - fixed_prompts.append(p) - req.prompts = fixed_prompts - - def forward( - self, - req: OmniDiffusionRequest, - prompt: str | None = None, - image: PIL.Image.Image | torch.Tensor | None = None, - height: int = 480, - width: int = 832, - num_inference_steps: int = 4, - guidance_scale: float | tuple[float, float] = 1.0, - frame_num: int = 81, - output_type: str | None = "np", - generator: torch.Generator | list[torch.Generator] | None = None, - prompt_embeds: torch.Tensor | None = None, - image_embeds: torch.Tensor | None = None, - last_image: PIL.Image.Image | torch.Tensor | None = None, - attention_kwargs: dict | None = None, - **kwargs, - ) -> DiffusionOutput: - self._verify_dmd2_request(req) - - _orig = self.scheduler.set_timesteps - - def _dmd2_set_timesteps(num_steps, device=None, **kw): - _orig(timesteps=self.DMD2_TIMESTEPS, device=device) - - self.scheduler.set_timesteps = _dmd2_set_timesteps - try: - return super().forward( - req, - prompt=prompt, - image=image, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=self.GUIDANCE_SCALE, - frame_num=frame_num, - output_type=output_type, - generator=generator, - prompt_embeds=prompt_embeds, - image_embeds=image_embeds, - last_image=last_image, - attention_kwargs=attention_kwargs, - **kwargs, - ) - finally: - self.scheduler.set_timesteps = _orig + self.__init_dmd2__() From a7fa2b30f2827d6d0025a506e44f919fff08d8be Mon Sep 17 00:00:00 2001 From: ayushag Date: Thu, 16 Apr 2026 16:56:51 -0700 Subject: [PATCH 5/7] chore: unified extensible structure Signed-off-by: ayushag --- vllm_omni/diffusion/models/dmd2/__init__.py | 8 ++ vllm_omni/diffusion/models/dmd2/mixin.py | 87 ++++++++++++++++++ .../diffusion/models/schedulers/__init__.py | 2 + .../schedulers/scheduling_dmd2_euler.py | 23 +++++ .../models/wan2_2/pipeline_wan2_2.py | 91 +------------------ .../models/wan2_2/pipeline_wan2_2_i2v.py | 2 +- 6 files changed, 122 insertions(+), 91 deletions(-) create mode 100644 vllm_omni/diffusion/models/dmd2/__init__.py create mode 100644 vllm_omni/diffusion/models/dmd2/mixin.py create mode 100644 vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py diff --git a/vllm_omni/diffusion/models/dmd2/__init__.py b/vllm_omni/diffusion/models/dmd2/__init__.py new file mode 100644 index 00000000000..d0c8219d4d1 --- /dev/null +++ b/vllm_omni/diffusion/models/dmd2/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.diffusion.models.dmd2.mixin import DMD2PipelineMixin + +__all__ = [ + "DMD2PipelineMixin", +] diff --git a/vllm_omni/diffusion/models/dmd2/mixin.py b/vllm_omni/diffusion/models/dmd2/mixin.py new file mode 100644 index 00000000000..29dadf0a0a5 --- /dev/null +++ b/vllm_omni/diffusion/models/dmd2/mixin.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import logging +import os + +from vllm_omni.diffusion.data import DiffusionOutput +from vllm_omni.diffusion.models.schedulers import DMD2EulerScheduler +from vllm_omni.diffusion.request import OmniDiffusionRequest + +logger = logging.getLogger(__name__) + + +class DMD2PipelineMixin: + """Mixin for FastGen DMD2-distilled models. Must appear before the base pipeline in MRO.""" + + def __init_dmd2__(self) -> None: + """Call after super().__init__() to apply DMD2 scheduler and read model_index.""" + # Deferred import: avoids cycle with wan2_2.pipeline_wan2_2 (which imports this mixin). + from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import _load_model_index + + local_files_only = os.path.exists(self.od_config.model) + model_index = _load_model_index(self.od_config.model, local_files_only) + + dmd2_timesteps = model_index.get("dmd2_denoising_timesteps", [999, 937, 833, 624]) + self.num_inference_steps = model_index.get("dmd2_num_inference_steps", 4) + shift = model_index.get("dmd2_scheduler_shift", 1.0) + self.dmd2_guidance_scale = model_index.get("dmd2_guidance_scale", 1.0) + + self.scheduler = DMD2EulerScheduler( + num_train_timesteps=1000, + shift=shift, + dmd2_timesteps=dmd2_timesteps, + ) + + def _sanitize_dmd2_request(self, req: OmniDiffusionRequest) -> None: + """Sanitize CFG-related fields in-place. Mutates req.sampling_params and req.prompts.""" + sp = req.sampling_params + + if sp.num_inference_steps and sp.num_inference_steps != self.num_inference_steps: + logger.warning( + "DMD2: ignoring num_inference_steps=%d, forcing %d.", + sp.num_inference_steps, + self.num_inference_steps, + ) + sp.num_inference_steps = self.num_inference_steps + + if sp.guidance_scale_provided and sp.guidance_scale != self.dmd2_guidance_scale: + logger.warning( + "DMD2: ignoring guidance_scale=%.2f, forcing %.2f.", + sp.guidance_scale, + self.dmd2_guidance_scale, + ) + sp.guidance_scale = self.dmd2_guidance_scale + sp.guidance_scale_provided = False + + if sp.guidance_scale_2 is not None: + logger.warning("DMD2: ignoring guidance_scale_2.") + sp.guidance_scale_2 = None + + if sp.true_cfg_scale is not None: + logger.warning("DMD2: ignoring true_cfg_scale.") + sp.true_cfg_scale = None + + sp.do_classifier_free_guidance = False + sp.is_cfg_negative = False + + fixed = [] + for p in req.prompts: + if isinstance(p, dict) and "negative_prompt" in p: + logger.warning("DMD2: ignoring negative_prompt.") + p = {k: v for k, v in p.items() if k != "negative_prompt"} + fixed.append(p) + req.prompts = fixed + + def forward(self, req: OmniDiffusionRequest, **kwargs) -> DiffusionOutput: + self._sanitize_dmd2_request(req) + kwargs.pop("guidance_scale", None) + kwargs.pop("num_inference_steps", None) + return super().forward( + req, + guidance_scale=self.dmd2_guidance_scale, + num_inference_steps=self.num_inference_steps, + **kwargs, + ) diff --git a/vllm_omni/diffusion/models/schedulers/__init__.py b/vllm_omni/diffusion/models/schedulers/__init__.py index 6f8df78ebf0..e683ed27203 100644 --- a/vllm_omni/diffusion/models/schedulers/__init__.py +++ b/vllm_omni/diffusion/models/schedulers/__init__.py @@ -1,10 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm_omni.diffusion.models.schedulers.scheduling_dmd2_euler import DMD2EulerScheduler from vllm_omni.diffusion.models.schedulers.scheduling_flow_unipc_multistep import ( FlowUniPCMultistepScheduler, ) __all__ = [ + "DMD2EulerScheduler", "FlowUniPCMultistepScheduler", ] diff --git a/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py b/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py new file mode 100644 index 00000000000..01447a41d77 --- /dev/null +++ b/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import torch +from diffusers import FlowMatchEulerDiscreteScheduler + + +class DMD2EulerScheduler(FlowMatchEulerDiscreteScheduler): + """Euler scheduler that always uses the fixed DMD2 training timestep schedule.""" + + def __init__(self, *args, dmd2_timesteps: list[int], **kwargs): + super().__init__(*args, **kwargs) + self._dmd2_timesteps = dmd2_timesteps + + def set_timesteps( + self, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + **kwargs, + ) -> None: + super().set_timesteps(timesteps=self._dmd2_timesteps, device=device) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 9352d3354a3..349cc34b369 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -12,7 +12,6 @@ import PIL.Image import torch -from diffusers import FlowMatchEulerDiscreteScheduler from diffusers.utils.torch_utils import randn_tensor from torch import nn from transformers import AutoTokenizer, UMT5EncoderModel @@ -23,6 +22,7 @@ from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler from vllm_omni.diffusion.models.wan2_2.scheduling_wan_euler import WanEulerScheduler @@ -957,95 +957,6 @@ def _load_model_index(model: str, local_files_only: bool) -> dict: return {} -class DMD2EulerScheduler(FlowMatchEulerDiscreteScheduler): - """Euler scheduler that always uses the fixed DMD2 training timestep schedule.""" - - def __init__(self, *args, dmd2_timesteps: list[int], **kwargs): - super().__init__(*args, **kwargs) - self._dmd2_timesteps = dmd2_timesteps - - def set_timesteps( - self, - num_inference_steps: int | None = None, - device: str | torch.device | None = None, - **kwargs, - ) -> None: - super().set_timesteps(timesteps=self._dmd2_timesteps, device=device) - - -class DMD2PipelineMixin: - """Mixin for FastGen DMD2-distilled models. Must appear before the base pipeline in MRO.""" - - def __init_dmd2__(self) -> None: - """Call after super().__init__() to apply DMD2 scheduler and read model_index.""" - local_files_only = os.path.exists(self.od_config.model) - model_index = _load_model_index(self.od_config.model, local_files_only) - - dmd2_timesteps = model_index.get("dmd2_denoising_timesteps", [999, 937, 833, 624]) - self.num_inference_steps = model_index.get("dmd2_num_inference_steps", 4) - shift = model_index.get("dmd2_scheduler_shift", 1.0) - self.dmd2_guidance_scale = model_index.get("dmd2_guidance_scale", 1.0) - - self.scheduler = DMD2EulerScheduler( - num_train_timesteps=1000, - shift=shift, - dmd2_timesteps=dmd2_timesteps, - ) - - def _sanitize_dmd2_request(self, req: OmniDiffusionRequest) -> None: - """Sanitize CFG-related fields in-place. Mutates req.sampling_params and req.prompts.""" - sp = req.sampling_params - - if sp.num_inference_steps and sp.num_inference_steps != self.num_inference_steps: - logger.warning( - "DMD2: ignoring num_inference_steps=%d, forcing %d.", - sp.num_inference_steps, - self.num_inference_steps, - ) - sp.num_inference_steps = self.num_inference_steps - - if sp.guidance_scale_provided and sp.guidance_scale != self.dmd2_guidance_scale: - logger.warning( - "DMD2: ignoring guidance_scale=%.2f, forcing %.2f.", - sp.guidance_scale, - self.dmd2_guidance_scale, - ) - sp.guidance_scale = self.dmd2_guidance_scale - sp.guidance_scale_provided = False - - if sp.guidance_scale_2 is not None: - logger.warning("DMD2: ignoring guidance_scale_2.") - sp.guidance_scale_2 = None - - if sp.true_cfg_scale is not None: - logger.warning("DMD2: ignoring true_cfg_scale.") - sp.true_cfg_scale = None - - sp.do_classifier_free_guidance = False - sp.is_cfg_negative = False - - fixed = [] - for p in req.prompts: - if isinstance(p, dict) and "negative_prompt" in p: - logger.warning("DMD2: ignoring negative_prompt.") - p = {k: v for k, v in p.items() if k != "negative_prompt"} - fixed.append(p) - req.prompts = fixed - - def forward(self, req: OmniDiffusionRequest, **kwargs) -> DiffusionOutput: - self._sanitize_dmd2_request(req) - # Safety: remove DMD2-controlled params from kwargs to avoid TypeError - # if a caller passes them explicitly alongside **kwargs. - kwargs.pop("guidance_scale", None) - kwargs.pop("num_inference_steps", None) - return super().forward( - req, - guidance_scale=self.dmd2_guidance_scale, - num_inference_steps=self.num_inference_steps, - **kwargs, - ) - - class WanT2VDMD2Pipeline(DMD2PipelineMixin, Wan22Pipeline): """Wan 2.1 T2V pipeline for FastGen DMD2-distilled 4-step models.""" diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index d8631867750..0922786ec40 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -22,10 +22,10 @@ from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( - DMD2PipelineMixin, _load_model_index, build_wan_scheduler, create_transformer_from_config, From 45aea7cc7cb7dfab2e3361502bc3fc10ccf6a294 Mon Sep 17 00:00:00 2001 From: ayushag Date: Thu, 16 Apr 2026 17:06:23 -0700 Subject: [PATCH 6/7] chore: util cleanup Signed-off-by: ayushag --- vllm_omni/diffusion/models/dmd2/mixin.py | 9 ++++---- .../models/magi_human/pipeline_magi_human.py | 15 +------------ vllm_omni/diffusion/models/utils.py | 21 +++++++++++++++++++ .../models/wan2_2/pipeline_wan2_2.py | 19 ----------------- .../models/wan2_2/pipeline_wan2_2_i2v.py | 7 +++++-- 5 files changed, 32 insertions(+), 39 deletions(-) create mode 100644 vllm_omni/diffusion/models/utils.py diff --git a/vllm_omni/diffusion/models/dmd2/mixin.py b/vllm_omni/diffusion/models/dmd2/mixin.py index 29dadf0a0a5..60c4b95baff 100644 --- a/vllm_omni/diffusion/models/dmd2/mixin.py +++ b/vllm_omni/diffusion/models/dmd2/mixin.py @@ -8,6 +8,7 @@ from vllm_omni.diffusion.data import DiffusionOutput from vllm_omni.diffusion.models.schedulers import DMD2EulerScheduler +from vllm_omni.diffusion.models.utils import _load_json from vllm_omni.diffusion.request import OmniDiffusionRequest logger = logging.getLogger(__name__) @@ -18,11 +19,11 @@ class DMD2PipelineMixin: def __init_dmd2__(self) -> None: """Call after super().__init__() to apply DMD2 scheduler and read model_index.""" - # Deferred import: avoids cycle with wan2_2.pipeline_wan2_2 (which imports this mixin). - from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import _load_model_index - local_files_only = os.path.exists(self.od_config.model) - model_index = _load_model_index(self.od_config.model, local_files_only) + try: + model_index = _load_json(self.od_config.model, "model_index.json", local_files_only) + except Exception: + model_index = {} dmd2_timesteps = model_index.get("dmd2_denoising_timesteps", [999, 937, 833, 624]) self.num_inference_steps = model_index.get("dmd2_num_inference_steps", 4) diff --git a/vllm_omni/diffusion/models/magi_human/pipeline_magi_human.py b/vllm_omni/diffusion/models/magi_human/pipeline_magi_human.py index 881c72edc6d..c1abdf91f04 100644 --- a/vllm_omni/diffusion/models/magi_human/pipeline_magi_human.py +++ b/vllm_omni/diffusion/models/magi_human/pipeline_magi_human.py @@ -48,6 +48,7 @@ ) from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.models.t5_encoder.t5_gemma_encoder import T5GemmaEncoderModelTP +from vllm_omni.diffusion.models.utils import _load_json from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import ( DiffusionPipelineProfilerMixin, ) @@ -1640,20 +1641,6 @@ def post_process(output): # =========================================================================== -def _load_json(model_path: str, filename: str, local_files_only: bool = True) -> dict: - """Load a JSON config file from a local path or HuggingFace Hub repo.""" - if local_files_only: - path = os.path.join(model_path, *filename.split("/")) - with open(path) as f: - return json.load(f) - else: - from huggingface_hub import hf_hub_download - - cached = hf_hub_download(repo_id=model_path, filename=filename) - with open(cached) as f: - return json.load(f) - - def _resolve_subdir( model_path: str, subfolder: str, diff --git a/vllm_omni/diffusion/models/utils.py b/vllm_omni/diffusion/models/utils.py new file mode 100644 index 00000000000..ba0d8dda20c --- /dev/null +++ b/vllm_omni/diffusion/models/utils.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import json +import os + + +def _load_json(model_path: str, filename: str, local_files_only: bool = True) -> dict: + """Load a JSON config file from a local path or HuggingFace Hub repo.""" + if local_files_only: + path = os.path.join(model_path, *filename.split("/")) + with open(path) as f: + return json.load(f) + else: + from huggingface_hub import hf_hub_download + + cached = hf_hub_download(repo_id=model_path, filename=filename) + with open(cached) as f: + return json.load(f) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 349cc34b369..b175d021955 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -938,25 +938,6 @@ def check_inputs( # --------------------------------------------------------------------------- -def _load_model_index(model: str, local_files_only: bool) -> dict: - """Load model_index.json from local path or HF Hub.""" - if local_files_only: - model_index_path = os.path.join(model, "model_index.json") - if os.path.exists(model_index_path): - with open(model_index_path) as f: - return json.load(f) - else: - try: - from huggingface_hub import hf_hub_download - - model_index_path = hf_hub_download(model, "model_index.json") - with open(model_index_path) as f: - return json.load(f) - except Exception: - pass - return {} - - class WanT2VDMD2Pipeline(DMD2PipelineMixin, Wan22Pipeline): """Wan 2.1 T2V pipeline for FastGen DMD2-distilled 4-step models.""" diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 0922786ec40..ba5e2f5ce4b 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -25,8 +25,8 @@ from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero +from vllm_omni.diffusion.models.utils import _load_json from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( - _load_model_index, build_wan_scheduler, create_transformer_from_config, load_transformer_config, @@ -174,7 +174,10 @@ def __init__( ] # Load model_index.json to detect available components - model_index = _load_model_index(model, local_files_only) + try: + model_index = _load_json(model, "model_index.json", local_files_only) + except Exception: + model_index = {} # Check if this is a two-stage model (MoE with transformer_2) self.has_transformer_2 = "transformer_2" in model_index From 1a970c873cc9d6420eb7ca4c9a642b17d8c2ed43 Mon Sep 17 00:00:00 2001 From: ayushag Date: Thu, 16 Apr 2026 17:33:14 -0700 Subject: [PATCH 7/7] chore: ltx2 add Signed-off-by: ayushag --- tests/diffusion/models/dmd2/__init__.py | 0 .../test_dmd2_request_sanitization.py} | 17 ++++++--- .../test_dmd2_scheduler.py} | 37 +++++++++++-------- vllm_omni/diffusion/models/ltx2/__init__.py | 4 ++ .../diffusion/models/ltx2/pipeline_ltx2.py | 9 +++++ .../models/ltx2/pipeline_ltx2_image2video.py | 9 +++++ .../models/wan2_2/pipeline_wan2_2.py | 2 +- .../models/wan2_2/pipeline_wan2_2_i2v.py | 2 +- vllm_omni/diffusion/registry.py | 12 ++++++ 9 files changed, 69 insertions(+), 23 deletions(-) create mode 100644 tests/diffusion/models/dmd2/__init__.py rename tests/diffusion/models/{wan2_2/test_wan_dmd2_request_sanitization.py => dmd2/test_dmd2_request_sanitization.py} (90%) rename tests/diffusion/models/{wan2_2/test_wan_dmd2_scheduler.py => dmd2/test_dmd2_scheduler.py} (70%) diff --git a/tests/diffusion/models/dmd2/__init__.py b/tests/diffusion/models/dmd2/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/diffusion/models/wan2_2/test_wan_dmd2_request_sanitization.py b/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py similarity index 90% rename from tests/diffusion/models/wan2_2/test_wan_dmd2_request_sanitization.py rename to tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py index 474c324518a..e270390bd99 100644 --- a/tests/diffusion/models/wan2_2/test_wan_dmd2_request_sanitization.py +++ b/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py @@ -5,23 +5,27 @@ import pytest import torch +from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import LTX2Pipeline, LTX2T2VDMD2Pipeline +from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import LTX2I2VDMD2Pipeline, LTX2ImageToVideoPipeline from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline, WanT2VDMD2Pipeline from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import Wan22I2VPipeline, WanI2VDMD2Pipeline from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams pytestmark = [pytest.mark.core_model, pytest.mark.cpu] -# Wan base pipeline whose __init__ loads model weights — mocked in tests. -_WAN_BASE = { +# DMD2 subclass → immediate base pipeline whose __init__ loads model weights (mocked in tests). +_DMD2_BASE = { WanT2VDMD2Pipeline: Wan22Pipeline, WanI2VDMD2Pipeline: Wan22I2VPipeline, + LTX2T2VDMD2Pipeline: LTX2Pipeline, + LTX2I2VDMD2Pipeline: LTX2ImageToVideoPipeline, } def _make_pipeline(cls): - """Run the DMD2 __init__ with the Wan base mocked out (no model weights loaded).""" + """Run the DMD2 __init__ with the base pipeline mocked out (no model weights loaded).""" - base = _WAN_BASE[cls] + base = _DMD2_BASE[cls] od_config = MagicMock() od_config.model = "/nonexistent" @@ -43,7 +47,10 @@ def _make_request(prompts=None, **sp_kwargs) -> OmniDiffusionRequest: ) -@pytest.fixture(params=[WanT2VDMD2Pipeline, WanI2VDMD2Pipeline], ids=["t2v", "i2v"]) +@pytest.fixture( + params=list(_DMD2_BASE.keys()), + ids=["wan_t2v", "wan_i2v", "ltx2_t2v", "ltx2_i2v"], +) def pipeline(request): return _make_pipeline(request.param) diff --git a/tests/diffusion/models/wan2_2/test_wan_dmd2_scheduler.py b/tests/diffusion/models/dmd2/test_dmd2_scheduler.py similarity index 70% rename from tests/diffusion/models/wan2_2/test_wan_dmd2_scheduler.py rename to tests/diffusion/models/dmd2/test_dmd2_scheduler.py index 99b097f2a07..32d00dbf18e 100644 --- a/tests/diffusion/models/wan2_2/test_wan_dmd2_scheduler.py +++ b/tests/diffusion/models/dmd2/test_dmd2_scheduler.py @@ -5,6 +5,8 @@ import pytest import torch +from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import LTX2Pipeline, LTX2T2VDMD2Pipeline +from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import LTX2I2VDMD2Pipeline, LTX2ImageToVideoPipeline from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline, WanT2VDMD2Pipeline from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import Wan22I2VPipeline, WanI2VDMD2Pipeline from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams @@ -13,19 +15,21 @@ _DMD2_TIMESTEPS = [999, 937, 833, 624] -# Wan base pipeline whose __init__ loads model weights — mocked in tests. -_WAN_BASE = { +# DMD2 subclass → immediate base pipeline whose __init__ loads model weights (mocked in tests). +_DMD2_BASE = { WanT2VDMD2Pipeline: Wan22Pipeline, WanI2VDMD2Pipeline: Wan22I2VPipeline, + LTX2T2VDMD2Pipeline: LTX2Pipeline, + LTX2I2VDMD2Pipeline: LTX2ImageToVideoPipeline, } def _make_pipeline(cls): - """Run the DMD2 __init__ (including __init_dmd2__) with the Wan base mocked.""" + """Run the DMD2 __init__ (including __init_dmd2__) with the base pipeline mocked.""" - base = _WAN_BASE[cls] + base = _DMD2_BASE[cls] od_config = MagicMock() - od_config.model = "/nonexistent" # _load_model_index returns {} → uses inline defaults + od_config.model = "/nonexistent" def _mock_base_init(self, *a, **kw): self.od_config = od_config # __init_dmd2__ needs this @@ -42,7 +46,10 @@ def _make_request(**sp_kwargs) -> OmniDiffusionRequest: return OmniDiffusionRequest(prompts=[{"prompt": "a cat"}], sampling_params=sp) -@pytest.fixture(params=[WanT2VDMD2Pipeline, WanI2VDMD2Pipeline], ids=["t2v", "i2v"]) +@pytest.fixture( + params=list(_DMD2_BASE.keys()), + ids=["wan_t2v", "wan_i2v", "ltx2_t2v", "ltx2_i2v"], +) def pipeline(request): return _make_pipeline(request.param) @@ -60,7 +67,7 @@ def _fake_parent_forward(self, req, *args, num_inference_steps=40, **kwargs): def test_forward_timesteps_match_dmd2_schedule(pipeline): """After forward() runs, scheduler.timesteps must equal the DMD2 training schedule.""" - parent = _WAN_BASE[type(pipeline)] + parent = _DMD2_BASE[type(pipeline)] # Baseline: calling set_timesteps(40) without the DMD2 override gives a different schedule pipeline.scheduler.set_timesteps(40, device="cpu") @@ -75,14 +82,12 @@ def test_forward_timesteps_match_dmd2_schedule(pipeline): assert pipeline.scheduler.timesteps.long().tolist() == _DMD2_TIMESTEPS -def test_forward_timesteps_fixed_across_num_steps(pipeline): - """scheduler.timesteps is always the DMD2 schedule regardless of num_steps passed.""" - parent = _WAN_BASE[type(pipeline)] +def test_forward_timesteps_idempotent_across_calls(pipeline): + """Successive forward() calls must not cause scheduler state to drift.""" + parent = _DMD2_BASE[type(pipeline)] - for num_steps in [1, 4, 10, 40, 100]: - with patch.object(parent, "forward", _fake_parent_forward): - pipeline.forward(_make_request()) + with patch.object(parent, "forward", _fake_parent_forward): + pipeline.forward(_make_request()) + pipeline.forward(_make_request()) - assert pipeline.scheduler.timesteps.long().tolist() == _DMD2_TIMESTEPS, ( - f"num_steps={num_steps}: got {pipeline.scheduler.timesteps.tolist()}" - ) + assert pipeline.scheduler.timesteps.long().tolist() == _DMD2_TIMESTEPS diff --git a/vllm_omni/diffusion/models/ltx2/__init__.py b/vllm_omni/diffusion/models/ltx2/__init__.py index 9f9d70f0106..2a78b61baeb 100644 --- a/vllm_omni/diffusion/models/ltx2/__init__.py +++ b/vllm_omni/diffusion/models/ltx2/__init__.py @@ -4,12 +4,14 @@ from vllm_omni.diffusion.models.ltx2.ltx2_transformer import LTX2VideoTransformer3DModel from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import ( LTX2Pipeline, + LTX2T2VDMD2Pipeline, LTX2TwoStagesPipeline, create_transformer_from_config, get_ltx2_post_process_func, load_transformer_config, ) from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import ( + LTX2I2VDMD2Pipeline, LTX2ImageToVideoPipeline, LTX2ImageToVideoTwoStagesPipeline, ) @@ -17,7 +19,9 @@ __all__ = [ "LTX2Pipeline", + "LTX2T2VDMD2Pipeline", "LTX2ImageToVideoPipeline", + "LTX2I2VDMD2Pipeline", "LTX2LatentUpsamplePipeline", "LTX2TwoStagesPipeline", "LTX2ImageToVideoTwoStagesPipeline", diff --git a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py index c60b192f0a5..f06ffab165a 100644 --- a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py +++ b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py @@ -33,6 +33,7 @@ from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.lora.request import LoRARequest @@ -1304,3 +1305,11 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + +class LTX2T2VDMD2Pipeline(DMD2PipelineMixin, LTX2Pipeline): + """LTX-2 T2V pipeline for FastGen DMD2-distilled models.""" + + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): + super().__init__(od_config=od_config, prefix=prefix) + self.__init_dmd2__() diff --git a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py index 65e7454b73f..50a71a54b61 100644 --- a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py +++ b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py @@ -25,6 +25,7 @@ from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.lora.request import LoRARequest @@ -889,3 +890,11 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + +class LTX2I2VDMD2Pipeline(DMD2PipelineMixin, LTX2ImageToVideoPipeline): + """LTX-2 I2V pipeline for FastGen DMD2-distilled models.""" + + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): + super().__init__(od_config=od_config, prefix=prefix) + self.__init_dmd2__() diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index b175d021955..fd1dc9e218c 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -939,7 +939,7 @@ def check_inputs( class WanT2VDMD2Pipeline(DMD2PipelineMixin, Wan22Pipeline): - """Wan 2.1 T2V pipeline for FastGen DMD2-distilled 4-step models.""" + """Wan 2.x T2V pipeline for FastGen DMD2-distilled models.""" def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): super().__init__(od_config=od_config, prefix=prefix) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index ba5e2f5ce4b..fe4f4cd703b 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -860,7 +860,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: class WanI2VDMD2Pipeline(DMD2PipelineMixin, Wan22I2VPipeline): - """Wan 2.1 I2V pipeline for FastGen DMD2-distilled 4-step models.""" + """Wan 2.x I2V pipeline for FastGen DMD2-distilled models.""" def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): super().__init__(od_config=od_config, prefix=prefix) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index fca93d1d736..4001109cc9e 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -83,6 +83,16 @@ "pipeline_ltx2_image2video", "LTX2ImageToVideoTwoStagesPipeline", ), + "LTX2T2VDMD2Pipeline": ( + "ltx2", + "pipeline_ltx2", + "LTX2T2VDMD2Pipeline", + ), + "LTX2I2VDMD2Pipeline": ( + "ltx2", + "pipeline_ltx2_image2video", + "LTX2I2VDMD2Pipeline", + ), "StableAudioPipeline": ( "stable_audio", "pipeline_stable_audio", @@ -367,6 +377,8 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "LTX2TwoStagesPipeline": "get_ltx2_post_process_func", "LTX2ImageToVideoPipeline": "get_ltx2_post_process_func", "LTX2ImageToVideoTwoStagesPipeline": "get_ltx2_post_process_func", + "LTX2T2VDMD2Pipeline": "get_ltx2_post_process_func", + "LTX2I2VDMD2Pipeline": "get_ltx2_post_process_func", "StableAudioPipeline": "get_stable_audio_post_process_func", "WanImageToVideoPipeline": "get_wan22_i2v_post_process_func", "WanT2VDMD2Pipeline": "get_wan22_post_process_func",