From 36c5bc6fff712a5a0f4ffb9982601924268e58c4 Mon Sep 17 00:00:00 2001 From: ayushag Date: Mon, 20 Apr 2026 20:28:42 -0700 Subject: [PATCH 1/5] chore: added dmd2 config Signed-off-by: ayushag --- .../dmd2/test_dmd2_request_sanitization.py | 18 +++++----- .../models/dmd2/test_dmd2_scheduler.py | 3 +- vllm_omni/diffusion/models/dmd2/__init__.py | 2 ++ vllm_omni/diffusion/models/dmd2/config.py | 36 +++++++++++++++++++ vllm_omni/diffusion/models/dmd2/mixin.py | 26 +++++++------- 5 files changed, 61 insertions(+), 24 deletions(-) create mode 100644 vllm_omni/diffusion/models/dmd2/config.py diff --git a/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py b/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py index e270390bd99..f4f2de966a5 100644 --- a/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py +++ b/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py @@ -63,13 +63,13 @@ def pipeline(request): def test_num_inference_steps_forced_to_dmd2_value(pipeline): req = _make_request(num_inference_steps=40) pipeline._sanitize_dmd2_request(req) - assert req.sampling_params.num_inference_steps == pipeline.num_inference_steps + assert req.sampling_params.num_inference_steps == pipeline.dmd2_config.num_inference_steps def test_num_inference_steps_already_correct(pipeline): - req = _make_request(num_inference_steps=pipeline.num_inference_steps) + req = _make_request(num_inference_steps=pipeline.dmd2_config.num_inference_steps) pipeline._sanitize_dmd2_request(req) - assert req.sampling_params.num_inference_steps == pipeline.num_inference_steps + assert req.sampling_params.num_inference_steps == pipeline.dmd2_config.num_inference_steps # --------------------------------------------------------------------------- @@ -80,19 +80,19 @@ def test_num_inference_steps_already_correct(pipeline): def test_guidance_scale_forced_to_one(pipeline): req = _make_request(guidance_scale=5.0, guidance_scale_provided=True) pipeline._sanitize_dmd2_request(req) - assert req.sampling_params.guidance_scale == pipeline.dmd2_guidance_scale + assert req.sampling_params.guidance_scale == pipeline.dmd2_config.guidance_scale assert req.sampling_params.guidance_scale_provided is False def test_guidance_scale_already_correct(pipeline): - req = _make_request(guidance_scale=pipeline.dmd2_guidance_scale, guidance_scale_provided=False) + req = _make_request(guidance_scale=pipeline.dmd2_config.guidance_scale, guidance_scale_provided=False) pipeline._sanitize_dmd2_request(req) - assert req.sampling_params.guidance_scale == pipeline.dmd2_guidance_scale + assert req.sampling_params.guidance_scale == pipeline.dmd2_config.guidance_scale def test_guidance_scale_provided_flag_cleared(pipeline): """guidance_scale_provided=True must be cleared even if scale is already dmd2_guidance_scale.""" - req = _make_request(guidance_scale=pipeline.dmd2_guidance_scale, guidance_scale_provided=True) + req = _make_request(guidance_scale=pipeline.dmd2_config.guidance_scale, guidance_scale_provided=True) pipeline._sanitize_dmd2_request(req) assert req.sampling_params.guidance_scale_provided is False @@ -166,13 +166,13 @@ def test_multiple_prompts_all_sanitized(pipeline): def test_clean_request_no_changes(pipeline): req = _make_request( - guidance_scale=pipeline.dmd2_guidance_scale, + guidance_scale=pipeline.dmd2_config.guidance_scale, guidance_scale_provided=False, do_classifier_free_guidance=False, is_cfg_negative=False, ) pipeline._sanitize_dmd2_request(req) - assert req.sampling_params.guidance_scale == pipeline.dmd2_guidance_scale + assert req.sampling_params.guidance_scale == pipeline.dmd2_config.guidance_scale assert req.sampling_params.guidance_scale_provided is False assert req.sampling_params.guidance_scale_2 is None assert req.sampling_params.true_cfg_scale is None diff --git a/tests/diffusion/models/dmd2/test_dmd2_scheduler.py b/tests/diffusion/models/dmd2/test_dmd2_scheduler.py index 32d00dbf18e..d92485652db 100644 --- a/tests/diffusion/models/dmd2/test_dmd2_scheduler.py +++ b/tests/diffusion/models/dmd2/test_dmd2_scheduler.py @@ -13,7 +13,8 @@ pytestmark = [pytest.mark.core_model, pytest.mark.cpu] -_DMD2_TIMESTEPS = [999, 937, 833, 624] +# Linspace fallback timesteps for num_inference_steps=4 (the mixin default when model_index is empty). +_DMD2_TIMESTEPS = [999, 749, 499, 249] # DMD2 subclass → immediate base pipeline whose __init__ loads model weights (mocked in tests). _DMD2_BASE = { diff --git a/vllm_omni/diffusion/models/dmd2/__init__.py b/vllm_omni/diffusion/models/dmd2/__init__.py index d0c8219d4d1..7d284cad473 100644 --- a/vllm_omni/diffusion/models/dmd2/__init__.py +++ b/vllm_omni/diffusion/models/dmd2/__init__.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm_omni.diffusion.models.dmd2.config import DMD2Config from vllm_omni.diffusion.models.dmd2.mixin import DMD2PipelineMixin __all__ = [ + "DMD2Config", "DMD2PipelineMixin", ] diff --git a/vllm_omni/diffusion/models/dmd2/config.py b/vllm_omni/diffusion/models/dmd2/config.py new file mode 100644 index 00000000000..a92724785b9 --- /dev/null +++ b/vllm_omni/diffusion/models/dmd2/config.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + + +@dataclass +class DMD2Config: + """Inference-time contract for a FastGen DMD2-distilled checkpoint.""" + + num_inference_steps: int = 4 + denoising_timesteps: list[int] | None = None + solver: str = "ode" + guidance_scale: float = 1.0 + + @classmethod + def from_model_index(cls, model_index: dict) -> DMD2Config: + """Read the `dmd2_config` block from a model_index.json dict. Missing block → defaults.""" + block = model_index.get("dmd2_config", {}) + return cls( + num_inference_steps=block.get("num_inference_steps", cls.num_inference_steps), + denoising_timesteps=block.get("denoising_timesteps"), + solver=block.get("solver", cls.solver), + guidance_scale=block.get("guidance_scale", cls.guidance_scale), + ) + + def resolve_timesteps(self) -> list[int]: + if self.denoising_timesteps is not None: + return list(self.denoising_timesteps) + # Get uniformly spaced timesteps from 999 to 0. + ts = torch.linspace(999, 0, self.num_inference_steps + 1)[:-1] + return ts.to(torch.int32).tolist() diff --git a/vllm_omni/diffusion/models/dmd2/mixin.py b/vllm_omni/diffusion/models/dmd2/mixin.py index 60c4b95baff..a464c7bd3f0 100644 --- a/vllm_omni/diffusion/models/dmd2/mixin.py +++ b/vllm_omni/diffusion/models/dmd2/mixin.py @@ -7,6 +7,7 @@ import os from vllm_omni.diffusion.data import DiffusionOutput +from vllm_omni.diffusion.models.dmd2.config import DMD2Config from vllm_omni.diffusion.models.schedulers import DMD2EulerScheduler from vllm_omni.diffusion.models.utils import _load_json from vllm_omni.diffusion.request import OmniDiffusionRequest @@ -25,36 +26,33 @@ def __init_dmd2__(self) -> None: except Exception: model_index = {} - dmd2_timesteps = model_index.get("dmd2_denoising_timesteps", [999, 937, 833, 624]) - self.num_inference_steps = model_index.get("dmd2_num_inference_steps", 4) - shift = model_index.get("dmd2_scheduler_shift", 1.0) - self.dmd2_guidance_scale = model_index.get("dmd2_guidance_scale", 1.0) + self.dmd2_config = DMD2Config.from_model_index(model_index) self.scheduler = DMD2EulerScheduler( num_train_timesteps=1000, - shift=shift, - dmd2_timesteps=dmd2_timesteps, + shift=1.0, + dmd2_timesteps=self.dmd2_config.resolve_timesteps(), ) def _sanitize_dmd2_request(self, req: OmniDiffusionRequest) -> None: """Sanitize CFG-related fields in-place. Mutates req.sampling_params and req.prompts.""" sp = req.sampling_params - if sp.num_inference_steps and sp.num_inference_steps != self.num_inference_steps: + if sp.num_inference_steps and sp.num_inference_steps != self.dmd2_config.num_inference_steps: logger.warning( "DMD2: ignoring num_inference_steps=%d, forcing %d.", sp.num_inference_steps, - self.num_inference_steps, + self.dmd2_config.num_inference_steps, ) - sp.num_inference_steps = self.num_inference_steps + sp.num_inference_steps = self.dmd2_config.num_inference_steps - if sp.guidance_scale_provided and sp.guidance_scale != self.dmd2_guidance_scale: + if sp.guidance_scale_provided and sp.guidance_scale != self.dmd2_config.guidance_scale: logger.warning( "DMD2: ignoring guidance_scale=%.2f, forcing %.2f.", sp.guidance_scale, - self.dmd2_guidance_scale, + self.dmd2_config.guidance_scale, ) - sp.guidance_scale = self.dmd2_guidance_scale + sp.guidance_scale = self.dmd2_config.guidance_scale sp.guidance_scale_provided = False if sp.guidance_scale_2 is not None: @@ -82,7 +80,7 @@ def forward(self, req: OmniDiffusionRequest, **kwargs) -> DiffusionOutput: kwargs.pop("num_inference_steps", None) return super().forward( req, - guidance_scale=self.dmd2_guidance_scale, - num_inference_steps=self.num_inference_steps, + guidance_scale=self.dmd2_config.guidance_scale, + num_inference_steps=self.dmd2_config.num_inference_steps, **kwargs, ) From a73f65c3e9acb5a05b572cb2c326f44bc4100e9e Mon Sep 17 00:00:00 2001 From: ayushag Date: Mon, 20 Apr 2026 20:57:28 -0700 Subject: [PATCH 2/5] chore: added stochastic sampling + defense against scheduler Signed-off-by: ayushag --- .../dmd2/test_dmd2_request_sanitization.py | 29 +++++++++++++++++++ .../models/dmd2/test_dmd2_scheduler.py | 21 ++++++++++++++ vllm_omni/diffusion/models/dmd2/mixin.py | 9 ++++++ .../schedulers/scheduling_dmd2_euler.py | 10 +++++-- 4 files changed, 67 insertions(+), 2 deletions(-) diff --git a/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py b/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py index f4f2de966a5..9f99abdd61a 100644 --- a/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py +++ b/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py @@ -164,6 +164,35 @@ def test_multiple_prompts_all_sanitized(pipeline): # --------------------------------------------------------------------------- +def test_sample_solver_stripped_from_extra_args(pipeline): + """[C1] defense: sample_solver must not leak into req for the base pipeline to read.""" + req = _make_request() + req.sampling_params.extra_args = {"sample_solver": "euler"} + pipeline._sanitize_dmd2_request(req) + assert "sample_solver" not in req.sampling_params.extra_args + + +def test_flow_shift_stripped_from_extra_args(pipeline): + """[C1] defense: flow_shift must not leak into req for the base pipeline to read.""" + req = _make_request() + req.sampling_params.extra_args = {"flow_shift": 3.0} + pipeline._sanitize_dmd2_request(req) + assert "flow_shift" not in req.sampling_params.extra_args + + +def test_unrelated_extra_args_preserved(pipeline): + """Sanitizer only strips sample_solver / flow_shift; other extras pass through.""" + req = _make_request() + req.sampling_params.extra_args = {"sample_solver": "euler", "unrelated": 42} + pipeline._sanitize_dmd2_request(req) + assert req.sampling_params.extra_args == {"unrelated": 42} + + +# --------------------------------------------------------------------------- +# Clean request — nothing changes +# --------------------------------------------------------------------------- + + def test_clean_request_no_changes(pipeline): req = _make_request( guidance_scale=pipeline.dmd2_config.guidance_scale, diff --git a/tests/diffusion/models/dmd2/test_dmd2_scheduler.py b/tests/diffusion/models/dmd2/test_dmd2_scheduler.py index d92485652db..69aa387f647 100644 --- a/tests/diffusion/models/dmd2/test_dmd2_scheduler.py +++ b/tests/diffusion/models/dmd2/test_dmd2_scheduler.py @@ -83,6 +83,27 @@ def test_forward_timesteps_match_dmd2_schedule(pipeline): assert pipeline.scheduler.timesteps.long().tolist() == _DMD2_TIMESTEPS +def test_default_solver_is_ode(pipeline): + """Default dmd2_config.solver is 'ode' → scheduler.stochastic_sampling is False.""" + assert pipeline.dmd2_config.solver == "ode" + assert pipeline.scheduler.config.stochastic_sampling is False + + +def test_sde_solver_plumbed_to_scheduler(): + """solver='sde' in model_index → scheduler.stochastic_sampling is True.""" + from vllm_omni.diffusion.models.dmd2 import DMD2Config + from vllm_omni.diffusion.models.schedulers import DMD2EulerScheduler + + cfg = DMD2Config.from_model_index({"dmd2_config": {"solver": "sde"}}) + scheduler = DMD2EulerScheduler( + num_train_timesteps=1000, + shift=1.0, + dmd2_timesteps=cfg.resolve_timesteps(), + stochastic_sampling=(cfg.solver == "sde"), + ) + assert scheduler.config.stochastic_sampling is True + + def test_forward_timesteps_idempotent_across_calls(pipeline): """Successive forward() calls must not cause scheduler state to drift.""" parent = _DMD2_BASE[type(pipeline)] diff --git a/vllm_omni/diffusion/models/dmd2/mixin.py b/vllm_omni/diffusion/models/dmd2/mixin.py index a464c7bd3f0..48e5719259c 100644 --- a/vllm_omni/diffusion/models/dmd2/mixin.py +++ b/vllm_omni/diffusion/models/dmd2/mixin.py @@ -32,6 +32,7 @@ def __init_dmd2__(self) -> None: num_train_timesteps=1000, shift=1.0, dmd2_timesteps=self.dmd2_config.resolve_timesteps(), + stochastic_sampling=(self.dmd2_config.solver == "sde"), ) def _sanitize_dmd2_request(self, req: OmniDiffusionRequest) -> None: @@ -66,6 +67,14 @@ def _sanitize_dmd2_request(self, req: OmniDiffusionRequest) -> None: sp.do_classifier_free_guidance = False sp.is_cfg_negative = False + # defense: strip scheduler-override extra_args that would let the base pipeline + # (e.g. Wan22Pipeline.forward) rebuild self.scheduler mid-forward and clobber DMD2EulerScheduler. + extra_args = getattr(sp, "extra_args", None) or {} + for key in ("sample_solver", "flow_shift"): + if key in extra_args: + logger.warning("DMD2: ignoring extra_args.%s.", key) + extra_args.pop(key) + fixed = [] for p in req.prompts: if isinstance(p, dict) and "negative_prompt" in p: diff --git a/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py b/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py index 01447a41d77..2b27718409a 100644 --- a/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py +++ b/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py @@ -10,8 +10,14 @@ class DMD2EulerScheduler(FlowMatchEulerDiscreteScheduler): """Euler scheduler that always uses the fixed DMD2 training timestep schedule.""" - def __init__(self, *args, dmd2_timesteps: list[int], **kwargs): - super().__init__(*args, **kwargs) + def __init__( + self, + *args, + dmd2_timesteps: list[int], + stochastic_sampling: bool = False, + **kwargs, + ): + super().__init__(*args, stochastic_sampling=stochastic_sampling, **kwargs) self._dmd2_timesteps = dmd2_timesteps def set_timesteps( From 8bba59548364bbf5db7947b2066f76e9e303ca7e Mon Sep 17 00:00:00 2001 From: ayushag Date: Mon, 20 Apr 2026 21:02:44 -0700 Subject: [PATCH 3/5] chore: add flux and Qwen Image Pipeline Signed-off-by: ayushag --- .../models/dmd2/test_dmd2_request_sanitization.py | 6 +++++- tests/diffusion/models/dmd2/test_dmd2_scheduler.py | 6 +++++- vllm_omni/diffusion/models/flux/__init__.py | 2 ++ vllm_omni/diffusion/models/flux/pipeline_flux.py | 9 +++++++++ vllm_omni/diffusion/models/qwen_image/__init__.py | 2 ++ .../models/qwen_image/pipeline_qwen_image.py | 9 +++++++++ vllm_omni/diffusion/registry.py | 12 ++++++++++++ 7 files changed, 44 insertions(+), 2 deletions(-) diff --git a/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py b/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py index 9f99abdd61a..f7ada6ebfeb 100644 --- a/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py +++ b/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py @@ -5,8 +5,10 @@ import pytest import torch +from vllm_omni.diffusion.models.flux.pipeline_flux import FluxDMD2Pipeline, FluxPipeline from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import LTX2Pipeline, LTX2T2VDMD2Pipeline from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import LTX2I2VDMD2Pipeline, LTX2ImageToVideoPipeline +from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import QwenImageDMD2Pipeline, QwenImagePipeline from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline, WanT2VDMD2Pipeline from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import Wan22I2VPipeline, WanI2VDMD2Pipeline from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams @@ -19,6 +21,8 @@ WanI2VDMD2Pipeline: Wan22I2VPipeline, LTX2T2VDMD2Pipeline: LTX2Pipeline, LTX2I2VDMD2Pipeline: LTX2ImageToVideoPipeline, + FluxDMD2Pipeline: FluxPipeline, + QwenImageDMD2Pipeline: QwenImagePipeline, } @@ -49,7 +53,7 @@ def _make_request(prompts=None, **sp_kwargs) -> OmniDiffusionRequest: @pytest.fixture( params=list(_DMD2_BASE.keys()), - ids=["wan_t2v", "wan_i2v", "ltx2_t2v", "ltx2_i2v"], + ids=["wan_t2v", "wan_i2v", "ltx2_t2v", "ltx2_i2v", "flux", "qwen_image"], ) def pipeline(request): return _make_pipeline(request.param) diff --git a/tests/diffusion/models/dmd2/test_dmd2_scheduler.py b/tests/diffusion/models/dmd2/test_dmd2_scheduler.py index 69aa387f647..9e4e1ccc5e7 100644 --- a/tests/diffusion/models/dmd2/test_dmd2_scheduler.py +++ b/tests/diffusion/models/dmd2/test_dmd2_scheduler.py @@ -5,8 +5,10 @@ import pytest import torch +from vllm_omni.diffusion.models.flux.pipeline_flux import FluxDMD2Pipeline, FluxPipeline from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import LTX2Pipeline, LTX2T2VDMD2Pipeline from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import LTX2I2VDMD2Pipeline, LTX2ImageToVideoPipeline +from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import QwenImageDMD2Pipeline, QwenImagePipeline from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline, WanT2VDMD2Pipeline from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import Wan22I2VPipeline, WanI2VDMD2Pipeline from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams @@ -22,6 +24,8 @@ WanI2VDMD2Pipeline: Wan22I2VPipeline, LTX2T2VDMD2Pipeline: LTX2Pipeline, LTX2I2VDMD2Pipeline: LTX2ImageToVideoPipeline, + FluxDMD2Pipeline: FluxPipeline, + QwenImageDMD2Pipeline: QwenImagePipeline, } @@ -49,7 +53,7 @@ def _make_request(**sp_kwargs) -> OmniDiffusionRequest: @pytest.fixture( params=list(_DMD2_BASE.keys()), - ids=["wan_t2v", "wan_i2v", "ltx2_t2v", "ltx2_i2v"], + ids=["wan_t2v", "wan_i2v", "ltx2_t2v", "ltx2_i2v", "flux", "qwen_image"], ) def pipeline(request): return _make_pipeline(request.param) diff --git a/vllm_omni/diffusion/models/flux/__init__.py b/vllm_omni/diffusion/models/flux/__init__.py index 6b13c4d965b..51a9000a5dc 100644 --- a/vllm_omni/diffusion/models/flux/__init__.py +++ b/vllm_omni/diffusion/models/flux/__init__.py @@ -7,6 +7,7 @@ FluxTransformer2DModel, ) from vllm_omni.diffusion.models.flux.pipeline_flux import ( + FluxDMD2Pipeline, FluxPipeline, get_flux_post_process_func, ) @@ -17,6 +18,7 @@ __all__ = [ "FluxPipeline", + "FluxDMD2Pipeline", "FluxKontextPipeline", "FluxTransformer2DModel", "FluxKontextTransformer2DModel", diff --git a/vllm_omni/diffusion/models/flux/pipeline_flux.py b/vllm_omni/diffusion/models/flux/pipeline_flux.py index 70d572d9a65..03bef7bd605 100644 --- a/vllm_omni/diffusion/models/flux/pipeline_flux.py +++ b/vllm_omni/diffusion/models/flux/pipeline_flux.py @@ -25,6 +25,7 @@ from vllm_omni.diffusion.distributed.parallel_state import get_classifier_free_guidance_world_size from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin from vllm_omni.diffusion.models.flux import FluxTransformer2DModel from vllm_omni.diffusion.models.flux.flux_pipeline_mixin import FluxPipelineMixin from vllm_omni.diffusion.models.t5_encoder import T5EncoderModel @@ -665,3 +666,11 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + +class FluxDMD2Pipeline(DMD2PipelineMixin, FluxPipeline): + """Flux pipeline for FastGen DMD2-distilled models.""" + + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): + super().__init__(od_config=od_config, prefix=prefix) + self.__init_dmd2__() diff --git a/vllm_omni/diffusion/models/qwen_image/__init__.py b/vllm_omni/diffusion/models/qwen_image/__init__.py index 4b823ec75dc..f0aaa1406ae 100644 --- a/vllm_omni/diffusion/models/qwen_image/__init__.py +++ b/vllm_omni/diffusion/models/qwen_image/__init__.py @@ -6,6 +6,7 @@ QwenImageCFGParallelMixin, ) from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import ( + QwenImageDMD2Pipeline, QwenImagePipeline, get_qwen_image_post_process_func, ) @@ -16,6 +17,7 @@ __all__ = [ "QwenImageCFGParallelMixin", "QwenImagePipeline", + "QwenImageDMD2Pipeline", "QwenImageTransformer2DModel", "get_qwen_image_post_process_func", ] diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index 9ef0cacd5a0..e1034e4109a 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -26,6 +26,7 @@ from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_qwenimage import DistributedAutoencoderKLQwenImage from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( QwenImageCFGParallelMixin, ) @@ -1031,3 +1032,11 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + +class QwenImageDMD2Pipeline(DMD2PipelineMixin, QwenImagePipeline): + """QwenImage pipeline for FastGen DMD2-distilled models.""" + + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): + super().__init__(od_config=od_config, prefix=prefix) + self.__init_dmd2__() diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 37f5199447c..2a75a075867 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -175,6 +175,16 @@ "pipeline_flux", "FluxPipeline", ), + "FluxDMD2Pipeline": ( + "flux", + "pipeline_flux", + "FluxDMD2Pipeline", + ), + "QwenImageDMD2Pipeline": ( + "qwen_image", + "pipeline_qwen_image", + "QwenImageDMD2Pipeline", + ), "OmniGen2Pipeline": ( "omnigen2", "pipeline_omnigen2", @@ -433,6 +443,8 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "Flux2KleinPipeline": "get_flux2_klein_post_process_func", "NextStep11Pipeline": "get_nextstep11_post_process_func", "FluxPipeline": "get_flux_post_process_func", + "FluxDMD2Pipeline": "get_flux_post_process_func", + "QwenImageDMD2Pipeline": "get_qwen_image_post_process_func", "OmniGen2Pipeline": "get_omnigen2_post_process_func", "HeliosPipeline": "get_helios_post_process_func", "HeliosPyramidPipeline": "get_helios_post_process_func", From 5396849b19356cc1b5cf01bf73874c1f96738f8e Mon Sep 17 00:00:00 2001 From: ayushag Date: Wed, 22 Apr 2026 11:59:23 -0700 Subject: [PATCH 4/5] fix: minor fixes Signed-off-by: ayushag --- .../models/dmd2/test_dmd2_scheduler.py | 21 +++++++++++++++++++ vllm_omni/diffusion/models/dmd2/config.py | 16 +++++++++++--- .../schedulers/scheduling_dmd2_euler.py | 2 ++ 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/tests/diffusion/models/dmd2/test_dmd2_scheduler.py b/tests/diffusion/models/dmd2/test_dmd2_scheduler.py index 9e4e1ccc5e7..7aac36970c8 100644 --- a/tests/diffusion/models/dmd2/test_dmd2_scheduler.py +++ b/tests/diffusion/models/dmd2/test_dmd2_scheduler.py @@ -108,6 +108,27 @@ def test_sde_solver_plumbed_to_scheduler(): assert scheduler.config.stochastic_sampling is True +def test_solver_case_insensitive(): + """'SDE', 'Sde', ' sde ' all normalize to 'sde'.""" + from vllm_omni.diffusion.models.dmd2 import DMD2Config + + for raw in ("SDE", "Sde", " sde ", "sde"): + cfg = DMD2Config.from_model_index({"dmd2_config": {"solver": raw}}) + assert cfg.solver == "sde" + + +def test_solver_invalid_raises(): + """Unknown solver strings raise ValueError with a clear message.""" + import pytest + + from vllm_omni.diffusion.models.dmd2 import DMD2Config + + with pytest.raises(ValueError, match="solver must be one of"): + DMD2Config.from_model_index({"dmd2_config": {"solver": "euler"}}) + with pytest.raises(ValueError, match="solver must be one of"): + DMD2Config(solver="dpmpp") # type: ignore[arg-type] + + def test_forward_timesteps_idempotent_across_calls(pipeline): """Successive forward() calls must not cause scheduler state to drift.""" parent = _DMD2_BASE[type(pipeline)] diff --git a/vllm_omni/diffusion/models/dmd2/config.py b/vllm_omni/diffusion/models/dmd2/config.py index a92724785b9..b4b993fb847 100644 --- a/vllm_omni/diffusion/models/dmd2/config.py +++ b/vllm_omni/diffusion/models/dmd2/config.py @@ -4,9 +4,12 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Literal, get_args import torch +Solver = Literal["ode", "sde"] + @dataclass class DMD2Config: @@ -14,23 +17,30 @@ class DMD2Config: num_inference_steps: int = 4 denoising_timesteps: list[int] | None = None - solver: str = "ode" + solver: Solver = "ode" guidance_scale: float = 1.0 + def __post_init__(self) -> None: + if self.solver not in get_args(Solver): + raise ValueError(f"DMD2Config.solver must be one of {list(get_args(Solver))}, got {self.solver!r}") + @classmethod def from_model_index(cls, model_index: dict) -> DMD2Config: """Read the `dmd2_config` block from a model_index.json dict. Missing block → defaults.""" block = model_index.get("dmd2_config", {}) + solver = block.get("solver", cls.solver) + if isinstance(solver, str): + solver = solver.strip().lower() return cls( num_inference_steps=block.get("num_inference_steps", cls.num_inference_steps), denoising_timesteps=block.get("denoising_timesteps"), - solver=block.get("solver", cls.solver), + solver=solver, guidance_scale=block.get("guidance_scale", cls.guidance_scale), ) def resolve_timesteps(self) -> list[int]: if self.denoising_timesteps is not None: return list(self.denoising_timesteps) - # Get uniformly spaced timesteps from 999 to 0. + # Uniformly spaced timesteps from 999 down toward 0, excluding the final 0. ts = torch.linspace(999, 0, self.num_inference_steps + 1)[:-1] return ts.to(torch.int32).tolist() diff --git a/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py b/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py index 2b27718409a..d09f14af00d 100644 --- a/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py +++ b/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py @@ -24,6 +24,8 @@ def set_timesteps( self, num_inference_steps: int | None = None, device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, **kwargs, ) -> None: super().set_timesteps(timesteps=self._dmd2_timesteps, device=device) From 8803c087e5a32132143f6103c112609fef20ec36 Mon Sep 17 00:00:00 2001 From: ayushag Date: Tue, 5 May 2026 14:42:18 -0700 Subject: [PATCH 5/5] fix: lint Signed-off-by: ayushag --- vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index 895aa99afc6..3f200afa9bc 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -26,8 +26,8 @@ from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_qwenimage import DistributedAutoencoderKLQwenImage from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin from vllm_omni.diffusion.model_loader.hub_prefetch import prefetch_subfolders +from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( QwenImageCFGParallelMixin, )