From cfb19c844611eb317ad843c3edddb01296337448 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 19 Jan 2026 16:10:01 +0800 Subject: [PATCH 01/67] refactor base pipeline and qwen_image_base_pipeline Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/models/base_pipeline.py | 250 ++++++++++++++++++ .../models/qwen_image/pipeline_qwen_image.py | 120 +-------- 2 files changed, 258 insertions(+), 112 deletions(-) create mode 100644 vllm_omni/diffusion/models/base_pipeline.py diff --git a/vllm_omni/diffusion/models/base_pipeline.py b/vllm_omni/diffusion/models/base_pipeline.py new file mode 100644 index 0000000000..24b89c2cea --- /dev/null +++ b/vllm_omni/diffusion/models/base_pipeline.py @@ -0,0 +1,250 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Base pipeline class for Qwen Image models with shared CFG functionality. +""" + +import torch +from torch import nn + +from vllm_omni.diffusion.distributed.parallel_state import ( + get_cfg_group, + get_classifier_free_guidance_rank, + get_classifier_free_guidance_world_size, +) + + +class BasePipeline(nn.Module): + """ + Base class for Diffusion pipelines providing shared CFG methods. + + All pipelines should inherit from this class to reuse + classifier-free guidance logic. + """ + + def predict_noise_maybe_with_cfg( + self, + do_true_cfg, + true_cfg_scale, + positive_kwargs, + negative_kwargs, + cfg_group=None, + cfg_rank=None, + cfg_normalize=True, + output_slice=None, + ): + """ + Predict noise with optional classifier-free guidance. 
+ + Args: + do_true_cfg: Whether to apply CFG + true_cfg_scale: CFG scale factor + positive_kwargs: Kwargs for positive/conditional prediction + negative_kwargs: Kwargs for negative/unconditional prediction + cfg_group: Communication group for CFG parallelism + cfg_rank: Rank in CFG parallel group + cfg_normalize: Whether to normalize CFG output (default: True) + output_slice: If set, slice output to [:, :output_slice] for image editing + + Returns: + Predicted noise tensor + """ + if do_true_cfg: + if cfg_group is not None: + # Enable CFG-parallel: rank0 computes positive, rank1 computes negative. + assert cfg_rank is not None, "cfg_rank must be provided if cfg_group is provided" + if cfg_rank == 0: + local_pred = self.predict_noise(**positive_kwargs) + else: + local_pred = self.predict_noise(**negative_kwargs) + + # Slice output for image editing pipelines (remove condition latents) + if output_slice is not None: + local_pred = local_pred[:, :output_slice] + + gathered = cfg_group.all_gather(local_pred, separate_tensors=True) + if cfg_rank == 0: + noise_pred = gathered[0] + neg_noise_pred = gathered[1] + noise_pred = self.combine_cfg_noise(noise_pred, neg_noise_pred, true_cfg_scale, cfg_normalize) + return noise_pred + else: + # Sequential CFG: compute both positive and negative + positive_noise_pred = self.predict_noise(**positive_kwargs) + negative_noise_pred = self.predict_noise(**negative_kwargs) + + # Slice output for image editing pipelines + if output_slice is not None: + positive_noise_pred = positive_noise_pred[:, :output_slice] + negative_noise_pred = negative_noise_pred[:, :output_slice] + + noise_pred = self.combine_cfg_noise( + positive_noise_pred, negative_noise_pred, true_cfg_scale, cfg_normalize + ) + return noise_pred + else: + # No CFG: only compute positive/conditional prediction + pred = self.predict_noise(**positive_kwargs) + if output_slice is not None: + pred = pred[:, :output_slice] + return pred + + def combine_cfg_noise(self, noise_pred, 
neg_noise_pred, true_cfg_scale, cfg_normalize=True): + """ + Combine conditional and unconditional noise predictions with CFG. + + Args: + noise_pred: Conditional noise prediction + neg_noise_pred: Unconditional noise prediction + true_cfg_scale: CFG scale factor + cfg_normalize: Whether to normalize the combined prediction (default: True) + + Returns: + Combined noise prediction tensor + """ + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + if cfg_normalize: + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + noise_pred = comb_pred + + return noise_pred + + def predict_noise(self, *args, **kwargs): + """ + Forward pass through transformer to predict noise. + + Subclasses should override this if they need custom behavior, + but the default implementation calls self.transformer. + """ + return self.transformer(*args, **kwargs)[0] + + def diffuse( + self, + *args, + **kwargs, + ): + """ + Diffusion loop with optional classifier-free guidance. + """ + raise NotImplementedError("Subclasses must implement diffuse") + + @property + def interrupt(self): + """Property to check if diffusion should be interrupted.""" + return getattr(self, "_interrupt", False) + + +class BaseQwenImagePipeline(BasePipeline): + """ + Base class for Qwen Image pipelines providing shared CFG methods. + """ + + def diffuse( + self, + prompt_embeds, + prompt_embeds_mask, + negative_prompt_embeds, + negative_prompt_embeds_mask, + latents, + img_shapes, + txt_seq_lens, + negative_txt_seq_lens, + timesteps, + do_true_cfg, + guidance, + true_cfg_scale, + image_latents=None, + cfg_normalize=True, + additional_transformer_kwargs=None, + ): + """ + Diffusion loop with optional classifier-free guidance. 
+ + Args: + prompt_embeds: Positive prompt embeddings + prompt_embeds_mask: Mask for positive prompt + negative_prompt_embeds: Negative prompt embeddings + negative_prompt_embeds_mask: Mask for negative prompt + latents: Noise latents to denoise + img_shapes: Image shape information + txt_seq_lens: Text sequence lengths for positive prompts + negative_txt_seq_lens: Text sequence lengths for negative prompts + timesteps: Diffusion timesteps + do_true_cfg: Whether to apply CFG + guidance: Guidance scale tensor + true_cfg_scale: CFG scale factor + image_latents: Conditional image latents for editing (default: None) + cfg_normalize: Whether to normalize CFG output (default: True) + additional_transformer_kwargs: Extra kwargs to pass to transformer (default: None) + + Returns: + Denoised latents + """ + self.scheduler.set_begin_index(0) + self.transformer.do_true_cfg = do_true_cfg + additional_transformer_kwargs = additional_transformer_kwargs or {} + + for i, t in enumerate(timesteps): + if self.interrupt: + continue + self._current_timestep = t + + # Broadcast timestep to match batch size + timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) + + # Concatenate image latents with noise latents if available (for editing pipelines) + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + # Enable CFG-parallel: rank0 computes positive, rank1 computes negative. 
+ cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 + cfg_group = get_cfg_group() if cfg_parallel_ready else None + cfg_rank = get_classifier_free_guidance_rank() if cfg_parallel_ready else None + + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states_mask": prompt_embeds_mask, + "encoder_hidden_states": prompt_embeds, + "img_shapes": img_shapes, + "txt_seq_lens": txt_seq_lens, + **additional_transformer_kwargs, + } + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states_mask": negative_prompt_embeds_mask, + "encoder_hidden_states": negative_prompt_embeds, + "img_shapes": img_shapes, + "txt_seq_lens": negative_txt_seq_lens, + **additional_transformer_kwargs, + } + + # For editing pipelines, we need to slice the output to remove condition latents + output_slice = latents.size(1) if image_latents is not None else None + + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg, + true_cfg_scale, + positive_kwargs, + negative_kwargs, + cfg_group, + cfg_rank, + cfg_normalize, + output_slice, + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if cfg_group is not None: + cfg_group.broadcast(latents, src=0) + + return latents diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index e0a37b8bc8..054783b3fe 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -20,18 +20,13 @@ FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.torch_utils import randn_tensor -from torch import nn from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer from vllm.model_executor.models.utils import 
AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.parallel_state import ( - get_cfg_group, - get_classifier_free_guidance_rank, - get_classifier_free_guidance_world_size, -) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.base_pipeline import BaseQwenImagePipeline from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, ) @@ -239,9 +234,7 @@ def apply_rotary_emb_qwen( return x_out.type_as(x) -class QwenImagePipeline( - nn.Module, -): +class QwenImagePipeline(BaseQwenImagePipeline): def __init__( self, *, @@ -536,109 +529,6 @@ def current_timestep(self): def interrupt(self): return self._interrupt - def diffuse( - self, - prompt_embeds, - prompt_embeds_mask, - negative_prompt_embeds, - negative_prompt_embeds_mask, - latents, - img_shapes, - txt_seq_lens, - negative_txt_seq_lens, - timesteps, - do_true_cfg, - guidance, - true_cfg_scale, - ): - self.scheduler.set_begin_index(0) - for i, t in enumerate(timesteps): - if self.interrupt: - continue - self._current_timestep = t - - # Broadcast timestep to match batch size - timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) - - # Enable CFG-parallel: rank0 computes positive, rank1 computes negative. 
- cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 - - self.transformer.do_true_cfg = do_true_cfg - - if cfg_parallel_ready: - cfg_group = get_cfg_group() - cfg_rank = get_classifier_free_guidance_rank() - - if cfg_rank == 0: - local_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - else: - local_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - - gathered = cfg_group.all_gather(local_pred, separate_tensors=True) - if cfg_rank == 0: - noise_pred = gathered[0] - neg_noise_pred = gathered[1] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - cfg_group.broadcast(latents, src=0) - - else: - # Forward pass for positive prompt (or unconditional if no CFG) - noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - # Forward pass for negative prompt (CFG) - if do_true_cfg: - neg_noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - 
guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - return latents - def forward( self, req: OmniDiffusionRequest, @@ -790,6 +680,12 @@ def forward( do_true_cfg, guidance, true_cfg_scale, + image_latents=None, + cfg_normalize=True, + additional_transformer_kwargs={ + "return_dict": False, + "attention_kwargs": self.attention_kwargs, + }, ) self._current_timestep = None From ae53dca7f55425df53e69b4eab5dd57f1ed2c3a6 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 19 Jan 2026 16:20:42 +0800 Subject: [PATCH 02/67] other pipelines Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../qwen_image/pipeline_qwen_image_edit.py | 130 +-------------- .../pipeline_qwen_image_edit_plus.py | 125 +-------------- .../qwen_image/pipeline_qwen_image_layered.py | 149 +----------------- 3 files changed, 24 insertions(+), 380 deletions(-) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index 4241370bab..84dd8bb486 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -25,13 +25,9 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from 
vllm_omni.diffusion.distributed.parallel_state import ( - get_cfg_group, - get_classifier_free_guidance_rank, - get_classifier_free_guidance_world_size, -) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.base_pipeline import BaseQwenImagePipeline from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import calculate_shift from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( @@ -213,10 +209,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class QwenImageEditPipeline( - nn.Module, - SupportImageInput, -): +class QwenImageEditPipeline(nn.Module, SupportImageInput, BaseQwenImagePipeline): def __init__( self, *, @@ -598,118 +591,6 @@ def current_timestep(self): def interrupt(self): return self._interrupt - def diffuse( - self, - prompt_embeds, - prompt_embeds_mask, - negative_prompt_embeds, - negative_prompt_embeds_mask, - latents, - image_latents, - img_shapes, - txt_seq_lens, - negative_txt_seq_lens, - timesteps, - do_true_cfg, - guidance, - true_cfg_scale, - ): - """Diffusion loop with optional image conditioning.""" - self.scheduler.set_begin_index(0) - for i, t in enumerate(timesteps): - if self.interrupt: - continue - self._current_timestep = t - # broadcast to batch dimension and place on same device/dtype as latents - timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) - - # Concatenate image latents with noise latents if available - latent_model_input = latents - if image_latents is not None: - latent_model_input = torch.cat([latents, image_latents], dim=1) - - self.transformer.do_true_cfg = do_true_cfg # used in teacache hook - cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 - - if cfg_parallel_ready: - 
cfg_group = get_cfg_group() - cfg_rank = get_classifier_free_guidance_rank() - - if cfg_rank == 0: - local_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - else: - local_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - - local_pred = local_pred[:, : latents.size(1)] - - gathered = cfg_group.all_gather(local_pred, separate_tensors=True) - if cfg_rank == 0: - noise_pred = gathered[0] - neg_noise_pred = gathered[1] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - cfg_group.broadcast(latents, src=0) - - else: - # Forward pass for positive prompt (or unconditional if no CFG) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_pred[:, : latents.size(1)] - - # Forward pass for negative prompt (CFG) - if do_true_cfg: - neg_noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - 
encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - return latents - def forward( self, req: OmniDiffusionRequest, @@ -886,7 +767,6 @@ def forward( negative_prompt_embeds, negative_prompt_embeds_mask, latents, - image_latents, img_shapes, txt_seq_lens, negative_txt_seq_lens, @@ -894,6 +774,12 @@ def forward( do_true_cfg, guidance, true_cfg_scale, + image_latents=image_latents, + cfg_normalize=True, + additional_transformer_kwargs={ + "return_dict": False, + "attention_kwargs": self.attention_kwargs, + }, ) self._current_timestep = None diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 9402c1f7ce..31c2b4170c 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -23,13 +23,9 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.parallel_state import ( - get_cfg_group, - get_classifier_free_guidance_rank, - get_classifier_free_guidance_world_size, -) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from 
vllm_omni.diffusion.models.base_pipeline import BaseQwenImagePipeline from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import calculate_shift from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit import ( @@ -168,7 +164,7 @@ def post_process_func( return post_process_func -class QwenImageEditPlusPipeline(nn.Module, SupportImageInput): +class QwenImageEditPlusPipeline(nn.Module, SupportImageInput, BaseQwenImagePipeline): def __init__( self, *, @@ -530,116 +526,6 @@ def current_timestep(self): def interrupt(self): return self._interrupt - def diffuse( - self, - prompt_embeds, - prompt_embeds_mask, - negative_prompt_embeds, - negative_prompt_embeds_mask, - latents, - image_latents, - img_shapes, - txt_seq_lens, - negative_txt_seq_lens, - timesteps, - do_true_cfg, - guidance, - true_cfg_scale, - ): - """Diffusion loop with optional image conditioning.""" - self.scheduler.set_begin_index(0) - for i, t in enumerate(timesteps): - if self.interrupt: - continue - self._current_timestep = t - # broadcast to batch dimension and place on same device/dtype as latents - timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) - - # Concatenate image latents with noise latents if available - latent_model_input = latents - if image_latents is not None: - latent_model_input = torch.cat([latents, image_latents], dim=1) - - self.transformer.do_true_cfg = do_true_cfg - cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 - - if cfg_parallel_ready: - cfg_group = get_cfg_group() - cfg_rank = get_classifier_free_guidance_rank() - - if cfg_rank == 0: - local_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - 
attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - else: - local_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - - local_pred = local_pred[:, : latents.size(1)] - - gathered = cfg_group.all_gather(local_pred, separate_tensors=True) - if cfg_rank == 0: - noise_pred = gathered[0] - neg_noise_pred = gathered[1] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - cfg_group.broadcast(latents, src=0) - - else: - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_pred[:, : latents.size(1)] - - if do_true_cfg: - neg_noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, 
dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - return latents - def forward( self, req: OmniDiffusionRequest, @@ -840,7 +726,6 @@ def forward( negative_prompt_embeds, negative_prompt_embeds_mask, latents, - image_latents, img_shapes, txt_seq_lens, negative_txt_seq_lens, @@ -848,6 +733,12 @@ def forward( do_true_cfg, guidance, true_cfg_scale, + image_latents=image_latents, + cfg_normalize=True, + additional_transformer_kwargs={ + "return_dict": False, + "attention_kwargs": self.attention_kwargs, + }, ) self._current_timestep = None diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index 1001e8a140..669548aad3 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -22,13 +22,9 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.parallel_state import ( - get_cfg_group, - get_classifier_free_guidance_rank, - get_classifier_free_guidance_world_size, -) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.base_pipeline import BaseQwenImagePipeline from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.qwen_image.autoencoder_kl_qwenimage import ( AutoencoderKLQwenImage, @@ -193,7 +189,6 @@ def retrieve_latents( class QwenImageLayeredPipeline(nn.Module, SupportImageInput): color_format = "RGBA" - def __init__( self, *, @@ -553,138 +548,6 @@ def _unpack_latents(latents, height, width, layers, vae_scale_factor): return 
latents - def diffuse( - self, - prompt_embeds, - prompt_embeds_mask, - negative_prompt_embeds, - negative_prompt_embeds_mask, - latents, - image_latents, - img_shapes, - txt_seq_lens, - negative_txt_seq_lens, - timesteps, - do_true_cfg, - guidance, - true_cfg_scale, - cfg_normalize, - is_rgb, - ): - self.scheduler.set_begin_index(0) - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - self._current_timestep = t - - latent_model_input = latents - if image_latents is not None: - latent_model_input = torch.cat([latents, image_latents], dim=1) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) - - # Enable CFG-parallel: rank0 computes positive, rank1 computes negative. - cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 - - self.transformer.do_true_cfg = do_true_cfg # used in teacache hook - - if cfg_parallel_ready: - cfg_group = get_cfg_group() - cfg_rank = get_classifier_free_guidance_rank() - - if cfg_rank == 0: - local_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - attention_kwargs=self.attention_kwargs, - additional_t_cond=is_rgb, - return_dict=False, - )[0] - local_pred = local_pred[:, : latents.size(1)] - else: - local_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - additional_t_cond=is_rgb, - return_dict=False, - )[0] - local_pred = local_pred[:, : latents.size(1)] - - gathered = cfg_group.all_gather(local_pred, separate_tensors=True) - if cfg_rank == 
0: - noise_pred = gathered[0] - neg_noise_pred = gathered[1] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - - if cfg_normalize: - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - else: - noise_pred = comb_pred - - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - cfg_group.broadcast(latents, src=0) - - else: - # Forward pass for positive prompt (or unconditional if no CFG) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - attention_kwargs=self.attention_kwargs, - additional_t_cond=is_rgb, - return_dict=False, - )[0] - noise_pred = noise_pred[:, : latents.size(1)] - - if do_true_cfg: - neg_noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - additional_t_cond=is_rgb, - return_dict=False, - )[0] - - neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - - if cfg_normalize: - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - else: - noise_pred = comb_pred - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - # no need callback now - return latents - @property def guidance_scale(self): return self._guidance_scale @@ -924,7 +787,6 @@ def 
forward( negative_prompt_embeds, negative_prompt_embeds_mask, latents, - image_latents, img_shapes, txt_seq_lens, negative_txt_seq_lens, @@ -932,8 +794,13 @@ def forward( do_true_cfg, guidance, true_cfg_scale, - cfg_normalize, - is_rgb, + image_latents=image_latents, + cfg_normalize=cfg_normalize, + additional_transformer_kwargs={ + "return_dict": False, + "additional_t_cond": is_rgb, + "attention_kwargs": self.attention_kwargs, + }, ) self._current_timestep = None From ecb9d34a7474689d6a84f8a20948521319aec810 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 19 Jan 2026 16:45:00 +0800 Subject: [PATCH 03/67] flux2 pipeline Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../flux2_klein/pipeline_flux2_klein.py | 187 ++++++++++++------ 1 file changed, 123 insertions(+), 64 deletions(-) diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index 2eac7b2ece..644fbe7558 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -24,7 +24,6 @@ import numpy as np import PIL.Image import torch -import torch.nn as nn from diffusers.image_processor import VaeImageProcessor from diffusers.models.autoencoders.autoencoder_kl_flux2 import AutoencoderKLFlux2 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps @@ -36,8 +35,14 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.parallel_state import ( + get_cfg_group, + get_classifier_free_guidance_rank, + get_classifier_free_guidance_world_size, +) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from 
def diffuse(
    self,
    latents,
    latent_ids,
    prompt_embeds,
    text_ids,
    negative_prompt_embeds,
    negative_text_ids,
    timesteps,
    do_true_cfg,
    guidance_scale,
    image_latents=None,
    image_latent_ids=None,
    cfg_normalize=False,
):
    """Run the denoising loop, optionally under classifier-free guidance.

    Args:
        latents: Noise latents to denoise.
        latent_ids: Position IDs for the latents.
        prompt_embeds: Positive prompt embeddings.
        text_ids: Position IDs for the positive text.
        negative_prompt_embeds: Negative prompt embeddings.
        negative_text_ids: Position IDs for the negative text.
        timesteps: Diffusion timesteps to iterate over.
        do_true_cfg: Whether to apply classifier-free guidance.
        guidance_scale: CFG scale factor.
        image_latents: Conditioning image latents (default: None).
        image_latent_ids: Position IDs for the conditioning latents (default: None).
        cfg_normalize: Whether to renormalize the CFG-combined prediction.

    Returns:
        The denoised latents.
    """
    # Setting the begin index up-front removes a DtoH sync per step
    # (see huggingface/diffusers#11696).
    self.scheduler.set_begin_index(0)

    for t in timesteps:
        if self.interrupt:
            continue

        self._current_timestep = t
        step_timestep = t.expand(latents.shape[0]).to(latents.dtype)

        model_input = latents.to(self.transformer.dtype)
        model_image_ids = latent_ids
        if image_latents is not None:
            # Conditioning latents are appended along the sequence dimension.
            model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
            model_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)

        # CFG parallelism: rank 0 evaluates the positive branch, rank 1 the negative.
        cfg_parallel_active = do_true_cfg and get_classifier_free_guidance_world_size() > 1
        cfg_group = get_cfg_group() if cfg_parallel_active else None
        cfg_rank = get_classifier_free_guidance_rank() if cfg_parallel_active else None

        shared_kwargs = {
            "hidden_states": model_input,
            "timestep": step_timestep / 1000,
            "guidance": None,
            "img_ids": model_image_ids,
            "joint_attention_kwargs": self.attention_kwargs,
        }
        positive_kwargs = dict(shared_kwargs, encoder_hidden_states=prompt_embeds, txt_ids=text_ids)
        negative_kwargs = dict(
            shared_kwargs, encoder_hidden_states=negative_prompt_embeds, txt_ids=negative_text_ids
        )

        # With image conditioning, drop the condition part of the prediction.
        output_slice = latents.size(1) if image_latents is not None else None

        noise_pred = self.predict_noise_maybe_with_cfg(
            do_true_cfg,
            guidance_scale,
            positive_kwargs,
            negative_kwargs,
            cfg_group,
            cfg_rank,
            cfg_normalize,
            output_slice,
        )

        # Scheduler step: x_t -> x_{t-1}.
        dtype_before_step = latents.dtype
        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        # Work around an MPS dtype-promotion bug (pytorch/pytorch#99272).
        if latents.dtype != dtype_before_step and torch.backends.mps.is_available():
            latents = latents.to(dtype_before_step)

        if cfg_group is not None:
            # Keep all CFG ranks in lockstep with rank 0's result.
            cfg_group.broadcast(latents, src=0)

    return latents
- # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 - self.scheduler.set_begin_index(0) - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - self._current_timestep = t - timestep = t.expand(latents.shape[0]).to(latents.dtype) - - latent_model_input = latents.to(self.transformer.dtype) - latent_image_ids = latent_ids - - if image_latents is not None: - latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) - latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) - - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=None, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - - noise_pred = noise_pred[:, : latents.size(1) :] - - if self.do_classifier_free_guidance: - neg_noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=None, - encoder_hidden_states=negative_prompt_embeds, - txt_ids=negative_text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - neg_noise_pred = neg_noise_pred[:, : latents.size(1) :] - noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) - - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype and torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + # 7. 
Denoising loop using diffuse method + latents = self.diffuse( + latents=latents, + latent_ids=latent_ids, + prompt_embeds=prompt_embeds, + text_ids=text_ids, + negative_prompt_embeds=negative_prompt_embeds if self.do_classifier_free_guidance else None, + negative_text_ids=negative_text_ids if self.do_classifier_free_guidance else None, + timesteps=timesteps, + do_true_cfg=self.do_classifier_free_guidance, + guidance_scale=guidance_scale, + image_latents=image_latents, + image_latent_ids=image_latent_ids, + cfg_normalize=False, # Flux2Klein doesn't use CFG normalization + ) self._current_timestep = None From a5641004b836bf4274affee871b7f1d0a7034c7f Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 19 Jan 2026 16:52:06 +0800 Subject: [PATCH 04/67] cfg normalize and longcat Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/models/base_pipeline.py | 20 +- .../longcat_image/pipeline_longcat_image.py | 173 ++++++++++++------ 2 files changed, 138 insertions(+), 55 deletions(-) diff --git a/vllm_omni/diffusion/models/base_pipeline.py b/vllm_omni/diffusion/models/base_pipeline.py index 24b89c2cea..df3519ae28 100644 --- a/vllm_omni/diffusion/models/base_pipeline.py +++ b/vllm_omni/diffusion/models/base_pipeline.py @@ -90,6 +90,22 @@ def predict_noise_maybe_with_cfg( pred = pred[:, :output_slice] return pred + def cfg_normalize_function(self, noise_pred, comb_pred): + """ + Normalize the combined noise prediction. 
def cfg_normalize_function(self, noise_pred, comb_pred):
    """Rescale the combined prediction so its last-dim norm matches the conditional one.

    Args:
        noise_pred: Positive (conditional) noise prediction.
        comb_pred: CFG-combined noise prediction.

    Returns:
        The renormalized combined prediction.
    """
    positive_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
    combined_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
    return comb_pred * (positive_norm / combined_norm)

def combine_cfg_noise(self, noise_pred, neg_noise_pred, true_cfg_scale, cfg_normalize=True):
    """Blend conditional and unconditional noise predictions with CFG.

    Args:
        noise_pred: Conditional (positive) noise prediction.
        neg_noise_pred: Unconditional (negative) noise prediction.
        true_cfg_scale: Guidance scale applied to the positive-negative difference.
        cfg_normalize: When True, renormalize the blended prediction via
            ``cfg_normalize_function`` (subclasses may override it).

    Returns:
        The guided noise prediction.
    """
    guided = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
    # NOTE(review): the final return is reconstructed from the call sites in
    # predict_noise_maybe_with_cfg, which consume this method's return value.
    return self.cfg_normalize_function(noise_pred, guided) if cfg_normalize else guided
def cfg_normalize_function(self, noise_pred, comb_pred, cfg_renorm_min=None):
    """Renormalize the CFG-combined prediction toward the conditional norm.

    Overrides the base implementation with a clamped rescale: the correction
    factor ``cond_norm / noise_norm`` is clamped to ``[cfg_renorm_min, 1.0]``,
    so renormalization can only shrink the combined prediction, never amplify it.

    Args:
        noise_pred: Positive (conditional) noise prediction.
        comb_pred: CFG-combined noise prediction.
        cfg_renorm_min: Lower clamp for the rescale factor. When None, falls
            back to the value configured by the current ``diffuse`` call
            (0.0 if ``diffuse`` has not run).

    Returns:
        The renormalized combined prediction.
    """
    if cfg_renorm_min is None:
        # Threaded through from diffuse(); replaces the previous
        # functools.partial re-binding (see note in diffuse below).
        cfg_renorm_min = getattr(self, "_cfg_renorm_min", 0.0)
    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
    # 1e-8 guards against division by a zero-norm prediction.
    scale = (cond_norm / (noise_norm + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
    return comb_pred * scale

def diffuse(
    self,
    latents,
    latent_image_ids,
    prompt_embeds,
    text_ids,
    negative_prompt_embeds,
    negative_text_ids,
    timesteps,
    do_true_cfg,
    guidance_scale,
    cfg_normalize=True,
    cfg_renorm_min=0.0,
):
    """Diffusion loop with optional classifier-free guidance.

    Args:
        latents: Noise latents to denoise.
        latent_image_ids: Position IDs for latents.
        prompt_embeds: Positive prompt embeddings.
        text_ids: Position IDs for positive text.
        negative_prompt_embeds: Negative prompt embeddings.
        negative_text_ids: Position IDs for negative text.
        timesteps: Diffusion timesteps.
        do_true_cfg: Whether to apply CFG.
        guidance_scale: CFG scale factor.
        cfg_normalize: Whether to renormalize the CFG output (default: True).
        cfg_renorm_min: Minimum value for CFG renormalization (default: 0.0).

    Returns:
        Denoised latents.
    """
    guidance = None

    # BUG FIX: stash the renorm floor for cfg_normalize_function instead of
    # `self.cfg_normalize_function = partial(self.cfg_normalize_function, ...)`,
    # which shadowed the method with an instance attribute and stacked a new
    # partial layer on every diffuse() call.
    self._cfg_renorm_min = cfg_renorm_min

    for t in timesteps:
        if self.interrupt:
            continue

        self._current_timestep = t
        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        # Enable CFG-parallel: rank0 computes positive, rank1 computes negative.
        cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1
        cfg_group = get_cfg_group() if cfg_parallel_ready else None
        cfg_rank = get_classifier_free_guidance_rank() if cfg_parallel_ready else None

        positive_kwargs = {
            "hidden_states": latents,
            "timestep": timestep / 1000,
            "guidance": guidance,
            "encoder_hidden_states": prompt_embeds,
            "txt_ids": text_ids,
            "img_ids": latent_image_ids,
        }
        negative_kwargs = {
            "hidden_states": latents,
            "timestep": timestep / 1000,
            "guidance": guidance,
            "encoder_hidden_states": negative_prompt_embeds,
            "txt_ids": negative_text_ids,
            "img_ids": latent_image_ids,
        }

        noise_pred = self.predict_noise_maybe_with_cfg(
            do_true_cfg,
            guidance_scale,
            positive_kwargs,
            negative_kwargs,
            cfg_group,
            cfg_rank,
            # BUG FIX: was hard-coded to True, silently ignoring the caller's
            # enable_cfg_renorm=False.
            cfg_normalize=cfg_normalize,
        )

        # compute the previous noisy sample x_t -> x_t-1
        latents_dtype = latents.dtype
        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        if latents.dtype != latents_dtype:
            if torch.backends.mps.is_available():
                # some platforms (eg. apple mps) misbehave due to a pytorch
                # bug: https://github.com/pytorch/pytorch/pull/99272
                latents = latents.to(latents_dtype)

        if cfg_group is not None:
            # Keep both CFG ranks consistent with rank 0's result.
            cfg_group.broadcast(latents, src=0)

    return latents
def diffuse(
    self,
    latents,
    image_latents,
    latent_image_ids,
    prompt_embeds,
    text_ids,
    negative_prompt_embeds,
    negative_text_ids,
    timesteps,
    do_true_cfg,
    guidance_scale,
    image_seq_len,
    cfg_normalize=False,
):
    """Denoising loop for image editing, optionally under classifier-free guidance.

    Args:
        latents: Noise latents to denoise.
        image_latents: Conditional image latents.
        latent_image_ids: Position IDs covering latents and condition image.
        prompt_embeds: Positive prompt embeddings.
        text_ids: Position IDs for the positive text.
        negative_prompt_embeds: Negative prompt embeddings.
        negative_text_ids: Position IDs for the negative text.
        timesteps: Diffusion timesteps.
        do_true_cfg: Whether to apply CFG.
        guidance_scale: CFG scale factor.
        image_seq_len: Sequence length kept when slicing off condition latents.
        cfg_normalize: Whether to renormalize the CFG output (default: False).

    Returns:
        The denoised latents.
    """
    guidance = None

    for t in timesteps:
        if self.interrupt:
            continue

        self._current_timestep = t

        # Condition latents (if any) ride along the sequence dimension.
        if image_latents is not None:
            model_input = torch.cat([latents, image_latents], dim=1)
        else:
            model_input = latents

        step_timestep = t.expand(model_input.shape[0]).to(latents.dtype)

        # CFG parallelism: rank 0 handles positive, rank 1 handles negative.
        cfg_parallel_active = do_true_cfg and get_classifier_free_guidance_world_size() > 1
        cfg_group = get_cfg_group() if cfg_parallel_active else None
        cfg_rank = get_classifier_free_guidance_rank() if cfg_parallel_active else None

        shared_kwargs = {
            "hidden_states": model_input,
            "timestep": step_timestep / 1000,
            "guidance": guidance,
            "img_ids": latent_image_ids,
        }
        positive_kwargs = dict(shared_kwargs, encoder_hidden_states=prompt_embeds, txt_ids=text_ids)
        negative_kwargs = dict(
            shared_kwargs, encoder_hidden_states=negative_prompt_embeds, txt_ids=negative_text_ids
        )

        # For editing, strip the condition part from the prediction.
        output_slice = image_seq_len if image_latents is not None else None

        noise_pred = self.predict_noise_maybe_with_cfg(
            do_true_cfg,
            guidance_scale,
            positive_kwargs,
            negative_kwargs,
            cfg_group,
            cfg_rank,
            cfg_normalize,
            output_slice,
        )

        # Scheduler step: x_t -> x_{t-1}.
        dtype_before_step = latents.dtype
        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        if latents.dtype != dtype_before_step:
            if torch.backends.mps.is_available():
                # MPS dtype-promotion workaround (pytorch/pytorch#99272).
                latents = latents.to(dtype_before_step)

        if cfg_group is not None:
            # Keep all CFG ranks in lockstep with rank 0's result.
            cfg_group.broadcast(latents, src=0)

    return latents
def diffuse(
    self,
    latents,
    timesteps,
    prompt_embeds,
    negative_prompt_embeds,
    text_ids,
    negative_text_ids,
    latent_image_ids,
    do_true_cfg,
    guidance_scale,
    cfg_normalize=False,
):
    """Denoising loop with optional classifier-free guidance.

    Args:
        latents: Noise latents to denoise.
        timesteps: Diffusion timesteps.
        prompt_embeds: Positive prompt embeddings.
        negative_prompt_embeds: Negative prompt embeddings.
        text_ids: Position IDs for the positive text.
        negative_text_ids: Position IDs for the negative text.
        latent_image_ids: Position IDs for the image latents.
        do_true_cfg: Whether to apply CFG.
        guidance_scale: CFG scale factor.
        cfg_normalize: Whether to renormalize the CFG output (default: False).

    Returns:
        The denoised latents.
    """
    # Set begin index up-front to remove a DtoH sync per step.
    self.scheduler.set_begin_index(0)

    for t in timesteps:
        # NOTE: this pipeline breaks out of the loop on interrupt rather than
        # skipping iterations like its siblings.
        if self.interrupt:
            break

        self._current_timestep = t
        step_timestep = t.expand(latents.shape[0]).to(latents.dtype)

        # CFG parallelism: rank 0 handles positive, rank 1 handles negative.
        cfg_parallel_active = do_true_cfg and get_classifier_free_guidance_world_size() > 1
        cfg_group = get_cfg_group() if cfg_parallel_active else None
        cfg_rank = get_classifier_free_guidance_rank() if cfg_parallel_active else None

        shared_kwargs = {
            "hidden_states": latents,
            "timestep": step_timestep / 1000,
            "img_ids": latent_image_ids,
            "return_dict": False,
        }
        positive_kwargs = dict(shared_kwargs, encoder_hidden_states=prompt_embeds, txt_ids=text_ids)
        negative_kwargs = dict(
            shared_kwargs, encoder_hidden_states=negative_prompt_embeds, txt_ids=negative_text_ids
        )

        noise_pred = self.predict_noise_maybe_with_cfg(
            do_true_cfg,
            guidance_scale,
            positive_kwargs,
            negative_kwargs,
            cfg_group,
            cfg_rank,
            cfg_normalize,
        )

        # Scheduler step: x_t -> x_{t-1}.
        dtype_before_step = latents.dtype
        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        if latents.dtype != dtype_before_step and torch.backends.mps.is_available():
            # MPS dtype-promotion workaround (pytorch/pytorch#99272).
            latents = latents.to(dtype_before_step)

        if cfg_group is not None:
            # Keep all CFG ranks in lockstep with rank 0's result.
            cfg_group.broadcast(latents, src=0)

    return latents
def diffuse(
    self,
    latents,
    timesteps,
    prompt_embeds,
    pooled_prompt_embeds,
    negative_prompt_embeds,
    negative_pooled_prompt_embeds,
    do_true_cfg,
    guidance_scale,
    cfg_normalize=False,
):
    """Denoising loop with optional classifier-free guidance.

    Args:
        latents: Noise latents to denoise.
        timesteps: Diffusion timesteps.
        prompt_embeds: Positive prompt embeddings.
        pooled_prompt_embeds: Pooled positive prompt embeddings.
        negative_prompt_embeds: Negative prompt embeddings.
        negative_pooled_prompt_embeds: Pooled negative prompt embeddings.
        do_true_cfg: Whether to apply CFG.
        guidance_scale: CFG scale factor.
        cfg_normalize: Whether to renormalize the CFG output (default: False).

    Returns:
        The denoised latents.
    """
    self.scheduler.set_begin_index(0)

    for t in timesteps:
        if self.interrupt:
            continue

        self._current_timestep = t
        # Broadcast timestep to match batch size.
        step_timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype)

        # CFG parallelism: rank 0 handles positive, rank 1 handles negative.
        cfg_parallel_active = do_true_cfg and get_classifier_free_guidance_world_size() > 1
        cfg_group = get_cfg_group() if cfg_parallel_active else None
        cfg_rank = get_classifier_free_guidance_rank() if cfg_parallel_active else None

        shared_kwargs = {
            "hidden_states": latents,
            "timestep": step_timestep,
            "return_dict": False,
        }
        positive_kwargs = dict(
            shared_kwargs,
            encoder_hidden_states=prompt_embeds,
            pooled_projections=pooled_prompt_embeds,
        )
        negative_kwargs = dict(
            shared_kwargs,
            encoder_hidden_states=negative_prompt_embeds,
            pooled_projections=negative_pooled_prompt_embeds,
        )

        noise_pred = self.predict_noise_maybe_with_cfg(
            do_true_cfg,
            guidance_scale,
            positive_kwargs,
            negative_kwargs,
            cfg_group,
            cfg_rank,
            cfg_normalize,
        )

        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        if cfg_group is not None:
            # Keep all CFG ranks in lockstep with rank 0's result.
            cfg_group.broadcast(latents, src=0)

    return latents
<33117903+wtomin@users.noreply.github.com> Date: Mon, 19 Jan 2026 17:33:37 +0800 Subject: [PATCH 08/67] updates Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/models/base_pipeline.py | 7 +++++++ .../models/qwen_image/pipeline_qwen_image_edit.py | 3 +-- .../models/qwen_image/pipeline_qwen_image_edit_plus.py | 3 +-- .../models/qwen_image/pipeline_qwen_image_layered.py | 1 - 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm_omni/diffusion/models/base_pipeline.py b/vllm_omni/diffusion/models/base_pipeline.py index df3519ae28..75088a3c24 100644 --- a/vllm_omni/diffusion/models/base_pipeline.py +++ b/vllm_omni/diffusion/models/base_pipeline.py @@ -23,6 +23,13 @@ class BasePipeline(nn.Module): classifier-free guidance logic. """ + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + def predict_noise_maybe_with_cfg( self, do_true_cfg, diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index 84dd8bb486..1fe9fdee79 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -20,7 +20,6 @@ FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.torch_utils import randn_tensor -from torch import nn from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor from vllm.model_executor.models.utils import AutoWeightsLoader @@ -209,7 +208,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class QwenImageEditPipeline(nn.Module, SupportImageInput, BaseQwenImagePipeline): +class QwenImageEditPipeline(SupportImageInput, BaseQwenImagePipeline): def __init__( self, *, diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py 
b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 31c2b4170c..d78834f05e 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -18,7 +18,6 @@ FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.torch_utils import randn_tensor -from torch import nn from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor from vllm.model_executor.models.utils import AutoWeightsLoader @@ -164,7 +163,7 @@ def post_process_func( return post_process_func -class QwenImageEditPlusPipeline(nn.Module, SupportImageInput, BaseQwenImagePipeline): +class QwenImageEditPlusPipeline(SupportImageInput, BaseQwenImagePipeline): def __init__( self, *, diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index 669548aad3..328742b198 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -17,7 +17,6 @@ FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.torch_utils import randn_tensor -from torch import nn from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor from vllm.model_executor.models.utils import AutoWeightsLoader From d5a618268248469491e583aae62e7bfc389ac2c4 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 19 Jan 2026 17:47:34 +0800 Subject: [PATCH 09/67] cfg mixin Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/models/base_pipeline.py | 16 +++++----------- .../models/qwen_image/pipeline_qwen_image.py | 5 +++-- .../qwen_image/pipeline_qwen_image_edit.py | 5 +++-- .../qwen_image/pipeline_qwen_image_edit_plus.py | 5 +++-- .../qwen_image/pipeline_qwen_image_layered.py | 5 +++-- 5 files 
changed, 17 insertions(+), 19 deletions(-) diff --git a/vllm_omni/diffusion/models/base_pipeline.py b/vllm_omni/diffusion/models/base_pipeline.py index 75088a3c24..39a139a509 100644 --- a/vllm_omni/diffusion/models/base_pipeline.py +++ b/vllm_omni/diffusion/models/base_pipeline.py @@ -5,8 +5,9 @@ Base pipeline class for Qwen Image models with shared CFG functionality. """ +from abc import ABCMeta + import torch -from torch import nn from vllm_omni.diffusion.distributed.parallel_state import ( get_cfg_group, @@ -15,21 +16,14 @@ ) -class BasePipeline(nn.Module): +class CFGParallelMixin(metaclass=ABCMeta): """ - Base class for Diffusion pipelines providing shared CFG methods. + Base Mixi class for Diffusion pipelines providing shared CFG methods. All pipelines should inherit from this class to reuse classifier-free guidance logic. """ - def __init__( - self, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - def predict_noise_maybe_with_cfg( self, do_true_cfg, @@ -160,7 +154,7 @@ def interrupt(self): return getattr(self, "_interrupt", False) -class BaseQwenImagePipeline(BasePipeline): +class QwenImageCFGParallelMixin(CFGParallelMixin): """ Base class for Qwen Image pipelines providing shared CFG methods. 
""" diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index 054783b3fe..20ae9096de 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -20,13 +20,14 @@ FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.torch_utils import randn_tensor +from torch import nn from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.base_pipeline import BaseQwenImagePipeline +from vllm_omni.diffusion.models.base_pipeline import QwenImageCFGParallelMixin from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, ) @@ -234,7 +235,7 @@ def apply_rotary_emb_qwen( return x_out.type_as(x) -class QwenImagePipeline(BaseQwenImagePipeline): +class QwenImagePipeline(nn.Module, QwenImageCFGParallelMixin): def __init__( self, *, diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index 1fe9fdee79..422a4cf09b 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -20,13 +20,14 @@ FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.torch_utils import randn_tensor +from torch import nn from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from 
vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.base_pipeline import BaseQwenImagePipeline +from vllm_omni.diffusion.models.base_pipeline import QwenImageCFGParallelMixin from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import calculate_shift from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( @@ -208,7 +209,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class QwenImageEditPipeline(SupportImageInput, BaseQwenImagePipeline): +class QwenImageEditPipeline(nn.Module, SupportImageInput, QwenImageCFGParallelMixin): def __init__( self, *, diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index d78834f05e..26a27f12e4 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -18,13 +18,14 @@ FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.torch_utils import randn_tensor +from torch import nn from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.base_pipeline import BaseQwenImagePipeline +from vllm_omni.diffusion.models.base_pipeline import QwenImageCFGParallelMixin from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import 
calculate_shift from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit import ( @@ -163,7 +164,7 @@ def post_process_func( return post_process_func -class QwenImageEditPlusPipeline(SupportImageInput, BaseQwenImagePipeline): +class QwenImageEditPlusPipeline(nn.Module, SupportImageInput, QwenImageCFGParallelMixin): def __init__( self, *, diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index 328742b198..c0555dc8c8 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -17,13 +17,14 @@ FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.torch_utils import randn_tensor +from torch import nn from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.base_pipeline import BaseQwenImagePipeline +from vllm_omni.diffusion.models.base_pipeline import QwenImageCFGParallelMixin from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.qwen_image.autoencoder_kl_qwenimage import ( AutoencoderKLQwenImage, @@ -186,7 +187,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class QwenImageLayeredPipeline(nn.Module, SupportImageInput): +class QwenImageLayeredPipeline(nn.Module, SupportImageInput, QwenImageCFGParallelMixin): color_format = "RGBA" def __init__( self, From 5253b8d285a650b2738fa8effa13d9477a628238 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 19 Jan 2026 
18:02:00 +0800 Subject: [PATCH 10/67] correct latent step Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/models/base_pipeline.py | 20 +++++++++---- .../flux2_klein/pipeline_flux2_klein.py | 28 +++++++++++++------ .../longcat_image/pipeline_longcat_image.py | 5 ++-- .../pipeline_longcat_image_edit.py | 28 +++++++++++++------ .../models/ovis_image/pipeline_ovis_image.py | 28 +++++++++++++------ .../diffusion/models/sd3/pipeline_sd3.py | 12 +++++--- 6 files changed, 84 insertions(+), 37 deletions(-) diff --git a/vllm_omni/diffusion/models/base_pipeline.py b/vllm_omni/diffusion/models/base_pipeline.py index 39a139a509..bff7077c86 100644 --- a/vllm_omni/diffusion/models/base_pipeline.py +++ b/vllm_omni/diffusion/models/base_pipeline.py @@ -18,7 +18,7 @@ class CFGParallelMixin(metaclass=ABCMeta): """ - Base Mixi class for Diffusion pipelines providing shared CFG methods. + Base Mixin class for Diffusion pipelines providing shared CFG methods. All pipelines should inherit from this class to reuse classifier-free guidance logic. @@ -69,7 +69,8 @@ def predict_noise_maybe_with_cfg( noise_pred = gathered[0] neg_noise_pred = gathered[1] noise_pred = self.combine_cfg_noise(noise_pred, neg_noise_pred, true_cfg_scale, cfg_normalize) - return noise_pred + + return noise_pred else: # Sequential CFG: compute both positive and negative positive_noise_pred = self.predict_noise(**positive_kwargs) @@ -156,7 +157,7 @@ def interrupt(self): class QwenImageCFGParallelMixin(CFGParallelMixin): """ - Base class for Qwen Image pipelines providing shared CFG methods. + Base Mixin class for Qwen Image pipelines providing shared CFG methods. 
""" def diffuse( @@ -258,8 +259,17 @@ def diffuse( ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if cfg_group is not None: - cfg_group.broadcast(latents, src=0) + if cfg_rank == 0: + latents = self.scheduler_step(noise_pred, t, latents) + cfg_group.broadcast(latents, src=0) + else: + latents = self.scheduler_step(noise_pred, t, latents) return latents + + def scheduler_step(self, noise_pred, t, latents): + """ + Step the scheduler. + """ + return self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index 644fbe7558..b1bc6d0d53 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -30,6 +30,7 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils.torch_utils import randn_tensor +from torch import nn from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM from vllm.logger import init_logger from vllm.model_executor.models.utils import AutoWeightsLoader @@ -42,7 +43,7 @@ ) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.base_pipeline import BasePipeline +from vllm_omni.diffusion.models.base_pipeline import CFGParallelMixin from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import ( Flux2Transformer2DModel, ) @@ -183,7 +184,7 @@ def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: return float(mu) -class Flux2KleinPipeline(BasePipeline, SupportImageInput): +class Flux2KleinPipeline(nn.Module, CFGParallelMixin, SupportImageInput): 
"""Flux2 klein pipeline for text-to-image generation.""" support_image_input = True @@ -736,17 +737,28 @@ def diffuse( ) # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if cfg_group is not None: + if cfg_rank == 0: + latents = self.scheduler_step(noise_pred, t, latents) + cfg_group.broadcast(latents, src=0) + else: + latents = self.scheduler_step(noise_pred, t, latents) - if latents.dtype != latents_dtype and torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) + return latents - if cfg_group is not None: - cfg_group.broadcast(latents, src=0) + def scheduler_step(self, noise_pred, t, latents): + """ + Step the scheduler. + """ + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) return latents + @torch.no_grad() def forward( self, req: OmniDiffusionRequest, diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py index a6104b5f44..063e847afc 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py @@ -18,6 +18,7 @@ from diffusers.pipelines.longcat_image.system_messages import SYSTEM_PROMPT_EN, SYSTEM_PROMPT_ZH from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, SchedulerMixin from diffusers.utils.torch_utils import randn_tensor +from torch import nn from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor from vllm.logger import init_logger from vllm.model_executor.models.utils import AutoWeightsLoader @@ -30,7 +31,7 @@ ) from vllm_omni.diffusion.distributed.utils import get_local_device from 
vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.base_pipeline import BasePipeline +from vllm_omni.diffusion.models.base_pipeline import CFGParallelMixin from vllm_omni.diffusion.models.longcat_image.longcat_image_transformer import LongCatImageTransformer2DModel from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.model_executor.model_loader.weight_utils import ( @@ -203,7 +204,7 @@ def get_prompt_language(prompt): return "en" -class LongCatImagePipeline(BasePipeline): +class LongCatImagePipeline(nn.Module, CFGParallelMixin): def __init__( self, *, diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py index b206c1a3e8..1a87f5b38f 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py @@ -16,6 +16,7 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils.torch_utils import randn_tensor +from torch import nn from transformers import ( AutoTokenizer, Qwen2_5_VLForConditionalGeneration, @@ -32,7 +33,7 @@ ) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.base_pipeline import BasePipeline +from vllm_omni.diffusion.models.base_pipeline import CFGParallelMixin from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.longcat_image.longcat_image_transformer import ( LongCatImageTransformer2DModel, @@ -220,7 +221,7 @@ def split_quotation(prompt, quote_pairs=None): return result -class LongCatImageEditPipeline(BasePipeline, SupportImageInput): +class 
LongCatImageEditPipeline(nn.Module, CFGParallelMixin, SupportImageInput): def __init__( self, *, @@ -487,16 +488,25 @@ def diffuse( ) # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if cfg_group is not None: + if cfg_rank == 0: + latents = self.scheduler_step(noise_pred, t, latents) + cfg_group.broadcast(latents, src=0) + else: + latents = self.scheduler_step(noise_pred, t, latents) - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) + return latents - if cfg_group is not None: - cfg_group.broadcast(latents, src=0) + def scheduler_step(self, noise_pred, t, latents): + """ + Step the scheduler. + """ + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) return latents def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): diff --git a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py index 71109215e7..46dee0d8e2 100644 --- a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py +++ b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py @@ -29,6 +29,7 @@ FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.torch_utils import randn_tensor +from torch import nn from transformers import Qwen2TokenizerFast, Qwen3Model from vllm.logger import init_logger from vllm.model_executor.models.utils import AutoWeightsLoader @@ -41,7 +42,7 @@ ) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.base_pipeline import BasePipeline +from vllm_omni.diffusion.models.base_pipeline import 
CFGParallelMixin from vllm_omni.diffusion.models.ovis_image.ovis_image_transformer import OvisImageTransformer2DModel from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific @@ -144,7 +145,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class OvisImagePipeline(BasePipeline): +class OvisImagePipeline(nn.Module, CFGParallelMixin): def __init__( self, *, @@ -507,16 +508,25 @@ def diffuse( ) # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if cfg_group is not None: + if cfg_rank == 0: + latents = self.scheduler_step(noise_pred, t, latents) + cfg_group.broadcast(latents, src=0) + else: + latents = self.scheduler_step(noise_pred, t, latents) - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) + return latents - if cfg_group is not None: - cfg_group.broadcast(latents, src=0) + def scheduler_step(self, noise_pred, t, latents): + """ + Step the scheduler. 
+ """ + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) return latents @property diff --git a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py index 5fa46fb398..478cf6bb33 100644 --- a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py +++ b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py @@ -11,6 +11,7 @@ FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.torch_utils import randn_tensor +from torch import nn from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5Tokenizer from vllm.model_executor.models.utils import AutoWeightsLoader @@ -25,7 +26,7 @@ download_weights_from_hf_specific, ) from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.base_pipeline import BasePipeline +from vllm_omni.diffusion.models.base_pipeline import CFGParallelMixin from vllm_omni.diffusion.models.sd3.sd3_transformer import ( SD3Transformer2DModel, ) @@ -131,7 +132,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class StableDiffusion3Pipeline(BasePipeline): +class StableDiffusion3Pipeline(nn.Module, CFGParallelMixin): def __init__( self, *, @@ -569,10 +570,13 @@ def diffuse( ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if cfg_group is not None: - cfg_group.broadcast(latents, src=0) + if cfg_rank == 0: + latents = self.scheduler_step(noise_pred, t, latents) + cfg_group.broadcast(latents, src=0) + else: + latents = self.scheduler_step(noise_pred, t, latents) return latents From 783bce47a08a54a9cc3b69cb8e7bf43dccc8fb54 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 19 Jan 2026 20:11:09 +0800 Subject: [PATCH 11/67] cfg 
broadcast correct Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/models/base_pipeline.py | 2 +- .../flux2_klein/pipeline_flux2_klein.py | 2 +- .../longcat_image/pipeline_longcat_image.py | 23 +++++++++++++------ .../pipeline_longcat_image_edit.py | 2 +- .../models/ovis_image/pipeline_ovis_image.py | 2 +- .../diffusion/models/sd3/pipeline_sd3.py | 2 +- 6 files changed, 21 insertions(+), 12 deletions(-) diff --git a/vllm_omni/diffusion/models/base_pipeline.py b/vllm_omni/diffusion/models/base_pipeline.py index bff7077c86..da4204555e 100644 --- a/vllm_omni/diffusion/models/base_pipeline.py +++ b/vllm_omni/diffusion/models/base_pipeline.py @@ -262,7 +262,7 @@ def diffuse( if cfg_group is not None: if cfg_rank == 0: latents = self.scheduler_step(noise_pred, t, latents) - cfg_group.broadcast(latents, src=0) + cfg_group.broadcast(latents, src=0) else: latents = self.scheduler_step(noise_pred, t, latents) diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index b1bc6d0d53..de005a2293 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -740,7 +740,7 @@ def diffuse( if cfg_group is not None: if cfg_rank == 0: latents = self.scheduler_step(noise_pred, t, latents) - cfg_group.broadcast(latents, src=0) + cfg_group.broadcast(latents, src=0) else: latents = self.scheduler_step(noise_pred, t, latents) diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py index 063e847afc..fc2becd949 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py @@ -484,16 +484,25 @@ def diffuse( ) # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = 
self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) - if cfg_group is not None: + if cfg_rank == 0: + latents = self.scheduler_step(noise_pred, t, latents) cfg_group.broadcast(latents, src=0) + else: + latents = self.scheduler_step(noise_pred, t, latents) + + return latents + + def scheduler_step(self, noise_pred, t, latents): + """ + Step the scheduler. + """ + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) return latents def prepare_latents( diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py index 1a87f5b38f..beb40ac491 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py @@ -491,7 +491,7 @@ def diffuse( if cfg_group is not None: if cfg_rank == 0: latents = self.scheduler_step(noise_pred, t, latents) - cfg_group.broadcast(latents, src=0) + cfg_group.broadcast(latents, src=0) else: latents = self.scheduler_step(noise_pred, t, latents) diff --git a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py index 46dee0d8e2..47ebec82a9 100644 --- a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py +++ b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py @@ -511,7 +511,7 @@ def diffuse( if cfg_group is not None: if cfg_rank == 0: latents = self.scheduler_step(noise_pred, t, latents) - cfg_group.broadcast(latents, src=0) + cfg_group.broadcast(latents, src=0) else: latents = self.scheduler_step(noise_pred, t, latents) diff --git 
a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py index 478cf6bb33..10c321cf03 100644 --- a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py +++ b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py @@ -574,7 +574,7 @@ def diffuse( if cfg_group is not None: if cfg_rank == 0: latents = self.scheduler_step(noise_pred, t, latents) - cfg_group.broadcast(latents, src=0) + cfg_group.broadcast(latents, src=0) else: latents = self.scheduler_step(noise_pred, t, latents) From 4158cdf8b5af7e2ab227c3076b667963632d11ef Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 19 Jan 2026 20:34:00 +0800 Subject: [PATCH 12/67] cfg parallel new name Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../{models/base_pipeline.py => distributed/cfg_parallel.py} | 0 vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py | 2 +- .../diffusion/models/longcat_image/pipeline_longcat_image.py | 2 +- .../models/longcat_image/pipeline_longcat_image_edit.py | 2 +- vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py | 2 +- vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py | 2 +- .../diffusion/models/qwen_image/pipeline_qwen_image_edit.py | 2 +- .../models/qwen_image/pipeline_qwen_image_edit_plus.py | 2 +- .../diffusion/models/qwen_image/pipeline_qwen_image_layered.py | 2 +- vllm_omni/diffusion/models/sd3/pipeline_sd3.py | 2 +- 10 files changed, 9 insertions(+), 9 deletions(-) rename vllm_omni/diffusion/{models/base_pipeline.py => distributed/cfg_parallel.py} (100%) diff --git a/vllm_omni/diffusion/models/base_pipeline.py b/vllm_omni/diffusion/distributed/cfg_parallel.py similarity index 100% rename from vllm_omni/diffusion/models/base_pipeline.py rename to vllm_omni/diffusion/distributed/cfg_parallel.py diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index 
de005a2293..928b5de3da 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -36,6 +36,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.parallel_state import ( get_cfg_group, get_classifier_free_guidance_rank, @@ -43,7 +44,6 @@ ) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.base_pipeline import CFGParallelMixin from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import ( Flux2Transformer2DModel, ) diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py index fc2becd949..466bf9ab10 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py @@ -24,6 +24,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.parallel_state import ( get_cfg_group, get_classifier_free_guidance_rank, @@ -31,7 +32,6 @@ ) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.base_pipeline import CFGParallelMixin from vllm_omni.diffusion.models.longcat_image.longcat_image_transformer import LongCatImageTransformer2DModel from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.model_executor.model_loader.weight_utils 
import ( diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py index beb40ac491..52634015b9 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py @@ -26,6 +26,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.parallel_state import ( get_cfg_group, get_classifier_free_guidance_rank, @@ -33,7 +34,6 @@ ) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.base_pipeline import CFGParallelMixin from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.longcat_image.longcat_image_transformer import ( LongCatImageTransformer2DModel, diff --git a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py index 47ebec82a9..4493e31c6c 100644 --- a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py +++ b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py @@ -35,6 +35,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.parallel_state import ( get_cfg_group, get_classifier_free_guidance_rank, @@ -42,7 +43,6 @@ ) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.base_pipeline import 
CFGParallelMixin from vllm_omni.diffusion.models.ovis_image.ovis_image_transformer import OvisImageTransformer2DModel from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index 20ae9096de..328bc53920 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -25,9 +25,9 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import QwenImageCFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.base_pipeline import QwenImageCFGParallelMixin from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, ) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index 422a4cf09b..1941748810 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -25,9 +25,9 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import QwenImageCFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.base_pipeline import QwenImageCFGParallelMixin from 
vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import calculate_shift from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 26a27f12e4..6661d3fd54 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -23,9 +23,9 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import QwenImageCFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.base_pipeline import QwenImageCFGParallelMixin from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import calculate_shift from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit import ( diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index c0555dc8c8..d712d83a34 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -22,9 +22,9 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import QwenImageCFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import 
DiffusersPipelineLoader -from vllm_omni.diffusion.models.base_pipeline import QwenImageCFGParallelMixin from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.qwen_image.autoencoder_kl_qwenimage import ( AutoencoderKLQwenImage, diff --git a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py index 10c321cf03..45db151e62 100644 --- a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py +++ b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py @@ -16,6 +16,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.parallel_state import ( get_cfg_group, get_classifier_free_guidance_rank, @@ -26,7 +27,6 @@ download_weights_from_hf_specific, ) from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.base_pipeline import CFGParallelMixin from vllm_omni.diffusion.models.sd3.sd3_transformer import ( SD3Transformer2DModel, ) From 2224e80094c7f4a71dc797f8c239761d7b81bfe6 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 21 Jan 2026 15:16:32 +0800 Subject: [PATCH 13/67] update cfg parallel logic Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../diffusion/distributed/cfg_parallel.py | 90 +++++++++++++------ .../flux2_klein/pipeline_flux2_klein.py | 22 +---- .../longcat_image/pipeline_longcat_image.py | 22 +---- .../pipeline_longcat_image_edit.py | 22 +---- .../models/ovis_image/pipeline_ovis_image.py | 22 +---- .../diffusion/models/sd3/pipeline_sd3.py | 23 +---- 6 files changed, 77 insertions(+), 124 deletions(-) diff --git a/vllm_omni/diffusion/distributed/cfg_parallel.py b/vllm_omni/diffusion/distributed/cfg_parallel.py index da4204555e..6aad6e2c6e 100644 --- 
a/vllm_omni/diffusion/distributed/cfg_parallel.py +++ b/vllm_omni/diffusion/distributed/cfg_parallel.py @@ -30,8 +30,6 @@ def predict_noise_maybe_with_cfg( true_cfg_scale, positive_kwargs, negative_kwargs, - cfg_group=None, - cfg_rank=None, cfg_normalize=True, output_slice=None, ): @@ -43,18 +41,21 @@ def predict_noise_maybe_with_cfg( true_cfg_scale: CFG scale factor positive_kwargs: Kwargs for positive/conditional prediction negative_kwargs: Kwargs for negative/unconditional prediction - cfg_group: Communication group for CFG parallelism - cfg_rank: Rank in CFG parallel group cfg_normalize: Whether to normalize CFG output (default: True) output_slice: If set, slice output to [:, :output_slice] for image editing Returns: - Predicted noise tensor + Predicted noise tensor (only valid on rank 0 in CFG parallel mode) """ if do_true_cfg: - if cfg_group is not None: + # Automatically detect CFG parallel configuration + cfg_parallel_ready = get_classifier_free_guidance_world_size() > 1 + + if cfg_parallel_ready: # Enable CFG-parallel: rank0 computes positive, rank1 computes negative. - assert cfg_rank is not None, "cfg_rank must be provided if cfg_group is provided" + cfg_group = get_cfg_group() + cfg_rank = get_classifier_free_guidance_rank() + if cfg_rank == 0: local_pred = self.predict_noise(**positive_kwargs) else: @@ -69,8 +70,9 @@ def predict_noise_maybe_with_cfg( noise_pred = gathered[0] neg_noise_pred = gathered[1] noise_pred = self.combine_cfg_noise(noise_pred, neg_noise_pred, true_cfg_scale, cfg_normalize) - return noise_pred + else: + return None else: # Sequential CFG: compute both positive and negative positive_noise_pred = self.predict_noise(**positive_kwargs) @@ -154,6 +156,55 @@ def interrupt(self): """Property to check if diffusion should be interrupted.""" return getattr(self, "_interrupt", False) + def scheduler_step(self, noise_pred, t, latents): + """ + Step the scheduler. 
+ + Args: + noise_pred: Predicted noise + t: Current timestep + latents: Current latents + + Returns: + Updated latents after scheduler step + """ + return self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + def scheduler_step_maybe_with_cfg(self, noise_pred, t, latents, do_true_cfg): + """ + Step the scheduler with (maybe) automatic CFG parallel synchronization. + + In CFG parallel mode, only rank 0 computes the scheduler step, + then broadcasts the result to other ranks. + + Args: + noise_pred: Predicted noise (only valid on rank 0 in CFG parallel) + t: Current timestep + latents: Current latents + do_true_cfg: Whether CFG is enabled + + Returns: + Updated latents (synchronized across all CFG ranks) + """ + # Automatically detect CFG parallel configuration + cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 + + if cfg_parallel_ready: + cfg_group = get_cfg_group() + cfg_rank = get_classifier_free_guidance_rank() + + # Only rank 0 computes the scheduler step + if cfg_rank == 0: + latents = self.scheduler_step(noise_pred, t, latents) + + # Broadcast the updated latents to all ranks + cfg_group.broadcast(latents, src=0) + else: + # No CFG parallel: directly compute scheduler step + latents = self.scheduler_step(noise_pred, t, latents) + + return latents + class QwenImageCFGParallelMixin(CFGParallelMixin): """ @@ -218,11 +269,6 @@ def diffuse( if image_latents is not None: latent_model_input = torch.cat([latents, image_latents], dim=1) - # Enable CFG-parallel: rank0 computes positive, rank1 computes negative. 
- cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 - cfg_group = get_cfg_group() if cfg_parallel_ready else None - cfg_rank = get_classifier_free_guidance_rank() if cfg_parallel_ready else None - positive_kwargs = { "hidden_states": latent_model_input, "timestep": timestep / 1000, @@ -247,29 +293,17 @@ def diffuse( # For editing pipelines, we need to slice the output to remove condition latents output_slice = latents.size(1) if image_latents is not None else None + # Predict noise with automatic CFG parallel handling noise_pred = self.predict_noise_maybe_with_cfg( do_true_cfg, true_cfg_scale, positive_kwargs, negative_kwargs, - cfg_group, - cfg_rank, cfg_normalize, output_slice, ) - # compute the previous noisy sample x_t -> x_t-1 - if cfg_group is not None: - if cfg_rank == 0: - latents = self.scheduler_step(noise_pred, t, latents) - cfg_group.broadcast(latents, src=0) - else: - latents = self.scheduler_step(noise_pred, t, latents) + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) return latents - - def scheduler_step(self, noise_pred, t, latents): - """ - Step the scheduler. 
- """ - return self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index 928b5de3da..84d248c118 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -37,11 +37,6 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin -from vllm_omni.diffusion.distributed.parallel_state import ( - get_cfg_group, - get_classifier_free_guidance_rank, - get_classifier_free_guidance_world_size, -) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import ( @@ -698,11 +693,6 @@ def diffuse( latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) - # Enable CFG-parallel: rank0 computes positive, rank1 computes negative - cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 - cfg_group = get_cfg_group() if cfg_parallel_ready else None - cfg_rank = get_classifier_free_guidance_rank() if cfg_parallel_ready else None - positive_kwargs = { "hidden_states": latent_model_input, "timestep": timestep / 1000, @@ -725,24 +715,18 @@ def diffuse( # For image conditioning, we need to slice the output to remove condition latents output_slice = latents.size(1) if image_latents is not None else None + # Predict noise with automatic CFG parallel handling noise_pred = self.predict_noise_maybe_with_cfg( do_true_cfg, guidance_scale, positive_kwargs, negative_kwargs, - cfg_group, - cfg_rank, cfg_normalize, output_slice, ) - # compute the previous noisy sample x_t -> x_t-1 - 
if cfg_group is not None: - if cfg_rank == 0: - latents = self.scheduler_step(noise_pred, t, latents) - cfg_group.broadcast(latents, src=0) - else: - latents = self.scheduler_step(noise_pred, t, latents) + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) return latents diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py index 466bf9ab10..da0c7da9be 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py @@ -25,11 +25,6 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin -from vllm_omni.diffusion.distributed.parallel_state import ( - get_cfg_group, - get_classifier_free_guidance_rank, - get_classifier_free_guidance_world_size, -) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.longcat_image.longcat_image_transformer import LongCatImageTransformer2DModel @@ -451,11 +446,6 @@ def diffuse( self._current_timestep = t timestep = t.expand(latents.shape[0]).to(latents.dtype) - # Enable CFG-parallel: rank0 computes positive, rank1 computes negative - cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 - cfg_group = get_cfg_group() if cfg_parallel_ready else None - cfg_rank = get_classifier_free_guidance_rank() if cfg_parallel_ready else None - positive_kwargs = { "hidden_states": latents, "timestep": timestep / 1000, @@ -473,23 +463,17 @@ def diffuse( "img_ids": latent_image_ids, } + # Predict noise with automatic CFG parallel handling noise_pred = self.predict_noise_maybe_with_cfg( do_true_cfg, guidance_scale, 
positive_kwargs, negative_kwargs, - cfg_group, - cfg_rank, cfg_normalize=True, ) - # compute the previous noisy sample x_t -> x_t-1 - if cfg_group is not None: - if cfg_rank == 0: - latents = self.scheduler_step(noise_pred, t, latents) - cfg_group.broadcast(latents, src=0) - else: - latents = self.scheduler_step(noise_pred, t, latents) + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) return latents diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py index 52634015b9..7bebd17839 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py @@ -27,11 +27,6 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin -from vllm_omni.diffusion.distributed.parallel_state import ( - get_cfg_group, - get_classifier_free_guidance_rank, - get_classifier_free_guidance_world_size, -) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput @@ -451,11 +446,6 @@ def diffuse( timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) - # Enable CFG-parallel: rank0 computes positive, rank1 computes negative - cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 - cfg_group = get_cfg_group() if cfg_parallel_ready else None - cfg_rank = get_classifier_free_guidance_rank() if cfg_parallel_ready else None - positive_kwargs = { "hidden_states": latent_model_input, "timestep": timestep / 1000, @@ -476,24 +466,18 @@ def diffuse( # For editing pipelines, we need to slice the output to 
remove condition latents output_slice = image_seq_len if image_latents is not None else None + # Predict noise with automatic CFG parallel handling noise_pred = self.predict_noise_maybe_with_cfg( do_true_cfg, guidance_scale, positive_kwargs, negative_kwargs, - cfg_group, - cfg_rank, cfg_normalize, output_slice, ) - # compute the previous noisy sample x_t -> x_t-1 - if cfg_group is not None: - if cfg_rank == 0: - latents = self.scheduler_step(noise_pred, t, latents) - cfg_group.broadcast(latents, src=0) - else: - latents = self.scheduler_step(noise_pred, t, latents) + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) return latents diff --git a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py index 4493e31c6c..06d22fea13 100644 --- a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py +++ b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py @@ -36,11 +36,6 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin -from vllm_omni.diffusion.distributed.parallel_state import ( - get_cfg_group, - get_classifier_free_guidance_rank, - get_classifier_free_guidance_world_size, -) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.ovis_image.ovis_image_transformer import OvisImageTransformer2DModel @@ -475,11 +470,6 @@ def diffuse( self._current_timestep = t timestep = t.expand(latents.shape[0]).to(latents.dtype) - # Enable CFG-parallel: rank0 computes positive, rank1 computes negative - cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 - cfg_group = get_cfg_group() if cfg_parallel_ready else None - cfg_rank = 
get_classifier_free_guidance_rank() if cfg_parallel_ready else None - positive_kwargs = { "hidden_states": latents, "timestep": timestep / 1000, @@ -497,23 +487,17 @@ def diffuse( "return_dict": False, } + # Predict noise with automatic CFG parallel handling noise_pred = self.predict_noise_maybe_with_cfg( do_true_cfg, guidance_scale, positive_kwargs, negative_kwargs, - cfg_group, - cfg_rank, cfg_normalize, ) - # compute the previous noisy sample x_t -> x_t-1 - if cfg_group is not None: - if cfg_rank == 0: - latents = self.scheduler_step(noise_pred, t, latents) - cfg_group.broadcast(latents, src=0) - else: - latents = self.scheduler_step(noise_pred, t, latents) + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) return latents diff --git a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py index 45db151e62..edeb6d1661 100644 --- a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py +++ b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py @@ -17,11 +17,6 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin -from vllm_omni.diffusion.distributed.parallel_state import ( - get_cfg_group, - get_classifier_free_guidance_rank, - get_classifier_free_guidance_world_size, -) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, @@ -539,11 +534,6 @@ def diffuse( # Broadcast timestep to match batch size timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) - # Enable CFG-parallel: rank0 computes positive, rank1 computes negative - cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 - cfg_group = get_cfg_group() if cfg_parallel_ready else None - cfg_rank = 
get_classifier_free_guidance_rank() if cfg_parallel_ready else None - positive_kwargs = { "hidden_states": latents, "timestep": timestep, @@ -559,24 +549,17 @@ def diffuse( "return_dict": False, } + # Predict noise with automatic CFG parallel handling noise_pred = self.predict_noise_maybe_with_cfg( do_true_cfg, guidance_scale, positive_kwargs, negative_kwargs, - cfg_group, - cfg_rank, cfg_normalize, ) - # compute the previous noisy sample x_t -> x_t-1 - - if cfg_group is not None: - if cfg_rank == 0: - latents = self.scheduler_step(noise_pred, t, latents) - cfg_group.broadcast(latents, src=0) - else: - latents = self.scheduler_step(noise_pred, t, latents) + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) return latents From 5b4a378763b468ad4ccd2692f772a48f9f95c88b Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 21 Jan 2026 15:29:11 +0800 Subject: [PATCH 14/67] fix flux2 Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/distributed/cfg_parallel.py | 5 ----- .../diffusion/models/flux2_klein/pipeline_flux2_klein.py | 4 ++++ 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm_omni/diffusion/distributed/cfg_parallel.py b/vllm_omni/diffusion/distributed/cfg_parallel.py index 6aad6e2c6e..71a0eb1f68 100644 --- a/vllm_omni/diffusion/distributed/cfg_parallel.py +++ b/vllm_omni/diffusion/distributed/cfg_parallel.py @@ -151,11 +151,6 @@ def diffuse( """ raise NotImplementedError("Subclasses must implement diffuse") - @property - def interrupt(self): - """Property to check if diffusion should be interrupted.""" - return getattr(self, "_interrupt", False) - def scheduler_step(self, noise_pred, t, latents): """ Step the scheduler. 
diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index 84d248c118..83f5f7c32a 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -641,6 +641,10 @@ def num_timesteps(self): def current_timestep(self): return self._current_timestep + @property + def interrupt(self): + return self._interrupt + def diffuse( self, latents, From 333015370db0b9e2b10140452cc16f4165eb12d8 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 21 Jan 2026 15:32:02 +0800 Subject: [PATCH 15/67] fix flux2 Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index 83f5f7c32a..45db72a494 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -705,6 +705,7 @@ def diffuse( "txt_ids": text_ids, "img_ids": latent_image_ids, "joint_attention_kwargs": self.attention_kwargs, + "return_dict": False, } negative_kwargs = { "hidden_states": latent_model_input, @@ -714,6 +715,7 @@ def diffuse( "txt_ids": negative_text_ids, "img_ids": latent_image_ids, "joint_attention_kwargs": self.attention_kwargs, + "return_dict": False, } # For image conditioning, we need to slice the output to remove condition latents From ef6122f9ec277a2cd68a7586d83db0346ec378b7 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 21 Jan 2026 15:44:27 +0800 Subject: [PATCH 16/67] fix longcat image Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- 
.../diffusion/models/longcat_image/pipeline_longcat_image.py | 2 ++ .../models/longcat_image/pipeline_longcat_image_edit.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py index da0c7da9be..e952cd3918 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py @@ -453,6 +453,7 @@ def diffuse( "encoder_hidden_states": prompt_embeds, "txt_ids": text_ids, "img_ids": latent_image_ids, + "return_dict": False, } negative_kwargs = { "hidden_states": latents, @@ -461,6 +462,7 @@ def diffuse( "encoder_hidden_states": negative_prompt_embeds, "txt_ids": negative_text_ids, "img_ids": latent_image_ids, + "return_dict": False, } # Predict noise with automatic CFG parallel handling diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py index 7bebd17839..24909e72b6 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py @@ -453,6 +453,7 @@ def diffuse( "encoder_hidden_states": prompt_embeds, "txt_ids": text_ids, "img_ids": latent_image_ids, + "return_dict": False, } negative_kwargs = { "hidden_states": latent_model_input, @@ -461,6 +462,7 @@ def diffuse( "encoder_hidden_states": negative_prompt_embeds, "txt_ids": negative_text_ids, "img_ids": latent_image_ids, + "return_dict": False, } # For editing pipelines, we need to slice the output to remove condition latents From 222d5fab804d173c5e31f1ddadf565e35a6ce8a7 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 21 Jan 2026 15:57:05 +0800 Subject: [PATCH 17/67] doc Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- 
.../diffusion/distributed/cfg_parallel.py | 31 +++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/distributed/cfg_parallel.py b/vllm_omni/diffusion/distributed/cfg_parallel.py index 71a0eb1f68..55bf6702e5 100644 --- a/vllm_omni/diffusion/distributed/cfg_parallel.py +++ b/vllm_omni/diffusion/distributed/cfg_parallel.py @@ -110,7 +110,7 @@ def cfg_normalize_function(self, noise_pred, comb_pred): noise_pred = comb_pred * (cond_norm / noise_norm) return noise_pred - def combine_cfg_noise(self, noise_pred, neg_noise_pred, true_cfg_scale, cfg_normalize=True): + def combine_cfg_noise(self, noise_pred, neg_noise_pred, true_cfg_scale, cfg_normalize=False): """ Combine conditional and unconditional noise predictions with CFG. @@ -118,7 +118,7 @@ def combine_cfg_noise(self, noise_pred, neg_noise_pred, true_cfg_scale, cfg_norm noise_pred: Conditional noise prediction neg_noise_pred: Unconditional noise prediction true_cfg_scale: CFG scale factor - cfg_normalize: Whether to normalize the combined prediction (default: True) + cfg_normalize: Whether to normalize the combined prediction (default: False) Returns: Combined noise prediction tensor @@ -148,6 +148,33 @@ def diffuse( ): """ Diffusion loop with optional classifier-free guidance. + + Subclasses MUST implement this method to define the complete + diffusion/denoising loop for their specific model. 
+ + Typical implementation pattern: + ```python + def diffuse(self, latents, timesteps, prompt_embeds, negative_embeds, ...): + for t in timesteps: + # Prepare kwargs for positive and negative predictions + positive_kwargs = {...} + negative_kwargs = {...} + + # Predict noise with automatic CFG handling + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=True, + true_cfg_scale=self.guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + ) + + # Step scheduler with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg( + noise_pred, t, latents, do_true_cfg=True + ) + + return latents + ``` """ raise NotImplementedError("Subclasses must implement diffuse") From 107f5ceab6b05587132e4c609264152d5eaa3013 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 21 Jan 2026 16:31:10 +0800 Subject: [PATCH 18/67] wan_2_2 pipelines Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../models/wan2_2/pipeline_wan2_2.py | 70 ++++++++++++----- .../models/wan2_2/pipeline_wan2_2_i2v.py | 77 +++++++++++++------ .../models/wan2_2/pipeline_wan2_2_ti2v.py | 73 ++++++++++++------ 3 files changed, 152 insertions(+), 68 deletions(-) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 562d8eec51..572865c451 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -18,6 +18,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler @@ -184,7 +185,7 
@@ def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: return pre_process_func -class Wan22Pipeline(nn.Module): +class Wan22Pipeline(nn.Module, CFGParallelMixin): def __init__( self, *, @@ -585,25 +586,39 @@ def forward( latent_model_input = latents.to(dtype) timestep = t.expand(latents.shape[0]) - noise_pred = current_model( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if current_guidance_scale > 1.0 and negative_prompt_embeds is not None: - noise_uncond = current_model( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) - - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None + # Prepare kwargs for positive and negative predictions + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": prompt_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": current_model, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": negative_prompt_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": current_model, + } + else: + negative_kwargs = None + + # Predict noise with automatic CFG parallel handling + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=current_guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = 
self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) self._current_timestep = None @@ -629,6 +644,21 @@ def forward( return DiffusionOutput(output=output) + def predict_noise(self, current_model=None, **kwargs): + """ + Forward pass through transformer to predict noise. + + Args: + current_model: The transformer model to use (transformer or transformer_2) + **kwargs: Arguments to pass to the transformer + + Returns: + Predicted noise tensor + """ + if current_model is None: + current_model = self.transformer + return current_model(**kwargs)[0] + def encode_prompt( self, prompt: str | list[str], diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index ea045e5d67..37f095ee43 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -18,6 +18,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput @@ -136,7 +137,7 @@ def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: return pre_process_func -class Wan22I2VPipeline(nn.Module, SupportImageInput): +class Wan22I2VPipeline(nn.Module, SupportImageInput, CFGParallelMixin): """ Wan2.2 Image-to-Video Pipeline. 
@@ -484,30 +485,41 @@ def forward( latent_model_input = torch.cat([latents, condition], dim=1).to(dtype) timestep = t.expand(latents.shape[0]) - # Forward pass - noise_pred = current_model( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - encoder_hidden_states_image=image_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - # Classifier-free guidance - if current_guidance_scale > 1.0 and negative_prompt_embeds is not None: - noise_uncond = current_model( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - encoder_hidden_states_image=image_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) - - # Scheduler step - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None + # Prepare kwargs for positive and negative predictions + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": prompt_embeds, + "encoder_hidden_states_image": image_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": current_model, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": negative_prompt_embeds, + "encoder_hidden_states_image": image_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": current_model, + } + else: + negative_kwargs = None + + # Predict noise with automatic CFG parallel handling + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=current_guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + ) + + # Compute the previous noisy sample x_t -> 
x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) self._current_timestep = None @@ -533,6 +545,21 @@ def forward( return DiffusionOutput(output=output) + def predict_noise(self, current_model=None, **kwargs): + """ + Forward pass through transformer to predict noise. + + Args: + current_model: The transformer model to use (transformer or transformer_2) + **kwargs: Arguments to pass to the transformer + + Returns: + Predicted noise tensor + """ + if current_model is None: + current_model = self.transformer + return current_model(**kwargs)[0] + def encode_prompt( self, prompt: str | list[str], diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index 5944b678a0..426f2500d7 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -31,6 +31,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput @@ -126,7 +127,7 @@ def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: return pre_process_func -class Wan22TI2VPipeline(nn.Module, SupportImageInput): +class Wan22TI2VPipeline(nn.Module, SupportImageInput, CFGParallelMixin): """ Wan2.2 Text-Image-to-Video (TI2V) Pipeline. 
@@ -399,28 +400,39 @@ def forward( temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) - # Forward pass - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - # Classifier-free guidance - if guidance_scale > 1.0 and negative_prompt_embeds is not None: - noise_uncond = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - - # Scheduler step - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None + # Prepare kwargs for positive and negative predictions + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": prompt_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": self.transformer, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": negative_prompt_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": self.transformer, + } + else: + negative_kwargs = None + + # Predict noise with automatic CFG parallel handling + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) self._current_timestep = None @@ -446,6 +458,21 @@ def forward( return 
DiffusionOutput(output=output) + def predict_noise(self, current_model=None, **kwargs): + """ + Forward pass through transformer to predict noise. + + Args: + current_model: The transformer model to use + **kwargs: Arguments to pass to the transformer + + Returns: + Predicted noise tensor + """ + if current_model is None: + current_model = self.transformer + return current_model(**kwargs)[0] + def encode_prompt( self, prompt: str | list[str], From d72429eb4912902822decb0797084d5ad76eeb75 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 21 Jan 2026 16:32:50 +0800 Subject: [PATCH 19/67] ovis image Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../models/ovis_image/pipeline_ovis_image.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py index 06d22fea13..3d7376a450 100644 --- a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py +++ b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py @@ -478,14 +478,17 @@ def diffuse( "img_ids": latent_image_ids, "return_dict": False, } - negative_kwargs = { - "hidden_states": latents, - "timestep": timestep / 1000, - "encoder_hidden_states": negative_prompt_embeds, - "txt_ids": negative_text_ids, - "img_ids": latent_image_ids, - "return_dict": False, - } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latents, + "timestep": timestep / 1000, + "encoder_hidden_states": negative_prompt_embeds, + "txt_ids": negative_text_ids, + "img_ids": latent_image_ids, + "return_dict": False, + } + else: + negative_kwargs = None # Predict noise with automatic CFG parallel handling noise_pred = self.predict_noise_maybe_with_cfg( From be3b8ba53bf5c16fc4cafe42f5898e2374dd0384 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 21 Jan 2026 16:33:25 
+0800 Subject: [PATCH 20/67] sd3 image Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/models/sd3/pipeline_sd3.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py index edeb6d1661..49b44cb729 100644 --- a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py +++ b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py @@ -541,13 +541,16 @@ def diffuse( "pooled_projections": pooled_prompt_embeds, "return_dict": False, } - negative_kwargs = { - "hidden_states": latents, - "timestep": timestep, - "encoder_hidden_states": negative_prompt_embeds, - "pooled_projections": negative_pooled_prompt_embeds, - "return_dict": False, - } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latents, + "timestep": timestep, + "encoder_hidden_states": negative_prompt_embeds, + "pooled_projections": negative_pooled_prompt_embeds, + "return_dict": False, + } + else: + negative_kwargs = None # Predict noise with automatic CFG parallel handling noise_pred = self.predict_noise_maybe_with_cfg( From 700fb7e97cfc4df115bac40a8058dc92301cf9cb Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 21 Jan 2026 16:46:09 +0800 Subject: [PATCH 21/67] pass neg kwargs if do_true_cfg Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../diffusion/distributed/cfg_parallel.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/vllm_omni/diffusion/distributed/cfg_parallel.py b/vllm_omni/diffusion/distributed/cfg_parallel.py index 55bf6702e5..19c946e956 100644 --- a/vllm_omni/diffusion/distributed/cfg_parallel.py +++ b/vllm_omni/diffusion/distributed/cfg_parallel.py @@ -301,16 +301,19 @@ def diffuse( "txt_seq_lens": txt_seq_lens, **additional_transformer_kwargs, } - negative_kwargs = { - "hidden_states": latent_model_input, - "timestep": 
timestep / 1000, - "guidance": guidance, - "encoder_hidden_states_mask": negative_prompt_embeds_mask, - "encoder_hidden_states": negative_prompt_embeds, - "img_shapes": img_shapes, - "txt_seq_lens": negative_txt_seq_lens, - **additional_transformer_kwargs, - } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states_mask": negative_prompt_embeds_mask, + "encoder_hidden_states": negative_prompt_embeds, + "img_shapes": img_shapes, + "txt_seq_lens": negative_txt_seq_lens, + **additional_transformer_kwargs, + } + else: + negative_kwargs = None # For editing pipelines, we need to slice the output to remove condition latents output_slice = latents.size(1) if image_latents is not None else None From a5c3dfb3a45b84b96712d3c255157041b76d985d Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 21 Jan 2026 16:48:59 +0800 Subject: [PATCH 22/67] flux pipeline reset to main Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../flux2_klein/pipeline_flux2_klein.py | 182 ++++++------------ 1 file changed, 61 insertions(+), 121 deletions(-) diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index 45db72a494..0496a7ec3a 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -24,19 +24,18 @@ import numpy as np import PIL.Image import torch +import torch.nn as nn from diffusers.image_processor import VaeImageProcessor from diffusers.models.autoencoders.autoencoder_kl_flux2 import AutoencoderKLFlux2 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents from diffusers.schedulers import 
FlowMatchEulerDiscreteScheduler from diffusers.utils.torch_utils import randn_tensor -from torch import nn from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM from vllm.logger import init_logger from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import ( @@ -179,7 +178,7 @@ def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: return float(mu) -class Flux2KleinPipeline(nn.Module, CFGParallelMixin, SupportImageInput): +class Flux2KleinPipeline(nn.Module, SupportImageInput): """Flux2 klein pipeline for text-to-image generation.""" support_image_input = True @@ -645,109 +644,6 @@ def current_timestep(self): def interrupt(self): return self._interrupt - def diffuse( - self, - latents, - latent_ids, - prompt_embeds, - text_ids, - negative_prompt_embeds, - negative_text_ids, - timesteps, - do_true_cfg, - guidance_scale, - image_latents=None, - image_latent_ids=None, - cfg_normalize=False, - ): - """ - Diffusion loop with optional classifier-free guidance. 
- - Args: - latents: Noise latents to denoise - latent_ids: Position IDs for latents - prompt_embeds: Positive prompt embeddings - text_ids: Position IDs for positive text - negative_prompt_embeds: Negative prompt embeddings - negative_text_ids: Position IDs for negative text - timesteps: Diffusion timesteps - do_true_cfg: Whether to apply CFG - guidance_scale: CFG scale factor - image_latents: Conditional image latents (default: None) - image_latent_ids: Position IDs for image latents (default: None) - cfg_normalize: Whether to normalize CFG output (default: False) - - Returns: - Denoised latents - """ - self.scheduler.set_begin_index(0) - - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - self._current_timestep = t - timestep = t.expand(latents.shape[0]).to(latents.dtype) - - # Prepare latent model input - latent_model_input = latents.to(self.transformer.dtype) - latent_image_ids = latent_ids - - if image_latents is not None: - latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) - latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) - - positive_kwargs = { - "hidden_states": latent_model_input, - "timestep": timestep / 1000, - "guidance": None, - "encoder_hidden_states": prompt_embeds, - "txt_ids": text_ids, - "img_ids": latent_image_ids, - "joint_attention_kwargs": self.attention_kwargs, - "return_dict": False, - } - negative_kwargs = { - "hidden_states": latent_model_input, - "timestep": timestep / 1000, - "guidance": None, - "encoder_hidden_states": negative_prompt_embeds, - "txt_ids": negative_text_ids, - "img_ids": latent_image_ids, - "joint_attention_kwargs": self.attention_kwargs, - "return_dict": False, - } - - # For image conditioning, we need to slice the output to remove condition latents - output_slice = latents.size(1) if image_latents is not None else None - - # Predict noise with automatic CFG parallel handling - noise_pred = self.predict_noise_maybe_with_cfg( - do_true_cfg, 
- guidance_scale, - positive_kwargs, - negative_kwargs, - cfg_normalize, - output_slice, - ) - - # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync - latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) - - return latents - - def scheduler_step(self, noise_pred, t, latents): - """ - Step the scheduler. - """ - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) - return latents - @torch.no_grad() def forward( self, @@ -1010,21 +906,65 @@ def forward( ) self._num_timesteps = len(timesteps) - # 7. Denoising loop using diffuse method - latents = self.diffuse( - latents=latents, - latent_ids=latent_ids, - prompt_embeds=prompt_embeds, - text_ids=text_ids, - negative_prompt_embeds=negative_prompt_embeds if self.do_classifier_free_guidance else None, - negative_text_ids=negative_text_ids if self.do_classifier_free_guidance else None, - timesteps=timesteps, - do_true_cfg=self.do_classifier_free_guidance, - guidance_scale=guidance_scale, - image_latents=image_latents, - image_latent_ids=image_latent_ids, - cfg_normalize=False, # Flux2Klein doesn't use CFG normalization - ) + # 7. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. 
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents.to(self.transformer.dtype) + latent_image_ids = latent_ids + + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) + latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1) :] + + if self.do_classifier_free_guidance: + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1) :] + noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype and torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) self._current_timestep = None From 
2292ccf818a3dfdd34f6c84fed41762f244ddd1e Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 21 Jan 2026 16:53:46 +0800 Subject: [PATCH 23/67] flux pipeline updates Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../flux2_klein/pipeline_flux2_klein.py | 83 ++++++++++++------- 1 file changed, 51 insertions(+), 32 deletions(-) diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index 0496a7ec3a..3ffd292ca3 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -36,6 +36,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import ( @@ -178,7 +179,7 @@ def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: return float(mu) -class Flux2KleinPipeline(nn.Module, SupportImageInput): +class Flux2KleinPipeline(nn.Module, CFGParallelMixin, SupportImageInput): """Flux2 klein pipeline for text-to-image generation.""" support_image_input = True @@ -924,38 +925,44 @@ def forward( latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=None, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - - noise_pred = noise_pred[:, : 
latents.size(1) :] - + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "guidance": None, + "encoder_hidden_states": prompt_embeds, + "txt_ids": text_ids, + "img_ids": latent_image_ids, + "joint_attention_kwargs": self.attention_kwargs, + "return_dict": False, + } if self.do_classifier_free_guidance: - neg_noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=None, - encoder_hidden_states=negative_prompt_embeds, - txt_ids=negative_text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - neg_noise_pred = neg_noise_pred[:, : latents.size(1) :] - noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) - - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype and torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "guidance": None, + "encoder_hidden_states": negative_prompt_embeds, + "txt_ids": negative_text_ids, + "img_ids": latent_image_ids, + "joint_attention_kwargs": self.attention_kwargs, + "return_dict": False, + } + else: + negative_kwargs = None + + # For editing pipelines, we need to slice the output to remove condition latents + output_slice = latents.size(1) if image_latents is not None else None + + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=self.do_classifier_free_guidance, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + output_slice=output_slice, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, self.do_classifier_free_guidance) if callback_on_step_end is not None: callback_kwargs = 
{} @@ -985,6 +992,18 @@ def forward( return DiffusionOutput(output=image) + def scheduler_step(self, noise_pred, t, latents): + """ + Step the scheduler. + """ + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + return latents + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) From bd25fea457837c28c14cf7bb7ba5946e3de17383 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 21 Jan 2026 16:56:09 +0800 Subject: [PATCH 24/67] reset longcat pipeline Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../longcat_image/pipeline_longcat_image.py | 167 ++++++------------ 1 file changed, 51 insertions(+), 116 deletions(-) diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py index e952cd3918..8b616ec45f 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py @@ -8,7 +8,6 @@ import os import re from collections.abc import Iterable -from functools import partial from typing import Any import numpy as np @@ -24,7 +23,6 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.longcat_image.longcat_image_transformer import LongCatImageTransformer2DModel @@ -199,7 +197,9 @@ def get_prompt_language(prompt): return "en" 
-class LongCatImagePipeline(nn.Module, CFGParallelMixin): +class LongCatImagePipeline( + nn.Module, +): def __init__( self, *, @@ -392,105 +392,6 @@ def _unpack_latents(latents, height, width, vae_scale_factor): def do_classifier_free_guidance(self): return self._guidance_scale > 1 - def cfg_normalize_function(self, noise_pred, comb_pred, cfg_renorm_min=0.0): - """ - Normalize the combined noise prediction. - """ - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - scale = (cond_norm / (noise_norm + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) - noise_pred = comb_pred * scale - return noise_pred - - def diffuse( - self, - latents, - latent_image_ids, - prompt_embeds, - text_ids, - negative_prompt_embeds, - negative_text_ids, - timesteps, - do_true_cfg, - guidance_scale, - cfg_normalize=True, - cfg_renorm_min=0.0, - ): - """ - Diffusion loop with optional classifier-free guidance. - - Args: - latents: Noise latents to denoise - latent_image_ids: Position IDs for latents - prompt_embeds: Positive prompt embeddings - text_ids: Position IDs for positive text - negative_prompt_embeds: Negative prompt embeddings - negative_text_ids: Position IDs for negative text - timesteps: Diffusion timesteps - do_true_cfg: Whether to apply CFG - guidance_scale: CFG scale factor - cfg_normalize: Whether to normalize CFG output with custom renorm (default: True) - cfg_renorm_min: Minimum value for CFG renormalization (default: 0.0) - - Returns: - Denoised latents - """ - guidance = None - - self.cfg_normalize_function = partial(self.cfg_normalize_function, cfg_renorm_min=cfg_renorm_min) - - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - self._current_timestep = t - timestep = t.expand(latents.shape[0]).to(latents.dtype) - - positive_kwargs = { - "hidden_states": latents, - "timestep": timestep / 1000, - "guidance": guidance, - "encoder_hidden_states": prompt_embeds, - "txt_ids": text_ids, - 
"img_ids": latent_image_ids, - "return_dict": False, - } - negative_kwargs = { - "hidden_states": latents, - "timestep": timestep / 1000, - "guidance": guidance, - "encoder_hidden_states": negative_prompt_embeds, - "txt_ids": negative_text_ids, - "img_ids": latent_image_ids, - "return_dict": False, - } - - # Predict noise with automatic CFG parallel handling - noise_pred = self.predict_noise_maybe_with_cfg( - do_true_cfg, - guidance_scale, - positive_kwargs, - negative_kwargs, - cfg_normalize=True, - ) - - # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync - latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) - - return latents - - def scheduler_step(self, noise_pred, t, latents): - """ - Step the scheduler. - """ - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) - return latents - def prepare_latents( self, batch_size, @@ -693,6 +594,9 @@ def forward( self._num_timesteps = len(timesteps) + # handle guidance + guidance = None + if self._joint_attention_kwargs is None: self._joint_attention_kwargs = {} @@ -700,20 +604,51 @@ def forward( if self.do_classifier_free_guidance: negative_prompt_embeds = negative_prompt_embeds.to(device) - # 6. Denoising loop using diffuse method - latents = self.diffuse( - latents=latents, - latent_image_ids=latent_image_ids, - prompt_embeds=prompt_embeds, - text_ids=text_ids, - negative_prompt_embeds=negative_prompt_embeds if self.do_classifier_free_guidance else None, - negative_text_ids=negative_text_ids if self.do_classifier_free_guidance else None, - timesteps=timesteps, - do_true_cfg=self.do_classifier_free_guidance, - guidance_scale=self._guidance_scale, - cfg_normalize=enable_cfg_renorm, - cfg_renorm_min=cfg_renorm_min, - ) + # 6. 
Denoising loop + for i, t in enumerate(timesteps): + if self._interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred_text = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred_uncond = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond) + + if enable_cfg_renorm: + cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True) + noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + scale = (cond_norm / (noise_norm + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) + noise_pred = noise_pred * scale + else: + noise_pred = noise_pred_text + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) self._current_timestep = None From fbe2b6f597d558df4141a877b348b6604a31c885 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 21 Jan 2026 17:01:19 +0800 Subject: [PATCH 25/67] updatge longcat pipeline Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../longcat_image/pipeline_longcat_image.py | 93 +++++++++++-------- 1 file changed, 56 insertions(+), 37 deletions(-) diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py index 8b616ec45f..3b75c36abf 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py @@ -8,6 +8,7 @@ import os import re from collections.abc import Iterable +from functools import partial from typing import Any import numpy as np @@ -23,6 +24,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.longcat_image.longcat_image_transformer import LongCatImageTransformer2DModel @@ -197,9 +199,7 @@ def get_prompt_language(prompt): return "en" -class LongCatImagePipeline( - nn.Module, -): +class LongCatImagePipeline(nn.Module, CFGParallelMixin): def __init__( self, *, @@ -460,6 +460,16 @@ def check_inputs( f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) + def cfg_normalize_function(self, noise_pred, comb_pred, cfg_renorm_min=0.0): + """ + Normalize the combined noise prediction. 
+ """ + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + scale = (cond_norm / (noise_norm + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) + noise_pred = comb_pred * scale + return noise_pred + def forward( self, req: OmniDiffusionRequest, @@ -604,6 +614,9 @@ def forward( if self.do_classifier_free_guidance: negative_prompt_embeds = negative_prompt_embeds.to(device) + # custom partial function with cfg_renorm_min + self.cfg_normalize_function = partial(self.cfg_normalize_function, cfg_renorm_min=cfg_renorm_min) + # 6. Denoising loop for i, t in enumerate(timesteps): if self._interrupt: @@ -612,43 +625,37 @@ def forward( self._current_timestep = t timestep = t.expand(latents.shape[0]).to(latents.dtype) - noise_pred_text = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - return_dict=False, - )[0] - + positive_kwargs = { + "hidden_states": latents, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states": prompt_embeds, + "txt_ids": text_ids, + "img_ids": latent_image_ids, + "return_dict": False, + } if self.do_classifier_free_guidance: - noise_pred_uncond = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - encoder_hidden_states=negative_prompt_embeds, - txt_ids=negative_text_ids, - img_ids=latent_image_ids, - return_dict=False, - )[0] - noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond) - - if enable_cfg_renorm: - cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True) - noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - scale = (cond_norm / (noise_norm + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) - noise_pred = noise_pred * scale + negative_kwargs = { + "hidden_states": latents, + "timestep": timestep / 1000, + "encoder_hidden_states": negative_prompt_embeds, + "txt_ids": 
negative_text_ids, + "img_ids": latent_image_ids, + "return_dict": False, + } else: - noise_pred = noise_pred_text - - # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + negative_kwargs = None + + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=self.do_classifier_free_guidance, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=enable_cfg_renorm, + ) - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, self.do_classifier_free_guidance) self._current_timestep = None @@ -665,6 +672,18 @@ def forward( return DiffusionOutput(output=image) + def scheduler_step(self, noise_pred, t, latents): + """ + Step the scheduler. 
+ """ + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + return latents + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights using AutoWeightsLoader for vLLM integration.""" loader = AutoWeightsLoader(self) From 900d37f7ca6ae7e924f1e9c96396edfe13da34eb Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 21 Jan 2026 17:02:21 +0800 Subject: [PATCH 26/67] reset longcat edit pipeline Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../pipeline_longcat_image_edit.py | 158 +++++------------- 1 file changed, 43 insertions(+), 115 deletions(-) diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py index 24909e72b6..f2c3fd648e 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py @@ -26,7 +26,6 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput @@ -216,7 +215,7 @@ def split_quotation(prompt, quote_pairs=None): return result -class LongCatImageEditPipeline(nn.Module, CFGParallelMixin, SupportImageInput): +class LongCatImageEditPipeline(nn.Module, SupportImageInput): def __init__( self, *, @@ -397,104 +396,6 @@ def _unpack_latents(latents, height, width, vae_scale_factor): return latents - def diffuse( 
- self, - latents, - image_latents, - latent_image_ids, - prompt_embeds, - text_ids, - negative_prompt_embeds, - negative_text_ids, - timesteps, - do_true_cfg, - guidance_scale, - image_seq_len, - cfg_normalize=False, - ): - """ - Diffusion loop with optional classifier-free guidance. - - Args: - latents: Noise latents to denoise - image_latents: Conditional image latents - latent_image_ids: Position IDs for latents and image - prompt_embeds: Positive prompt embeddings - text_ids: Position IDs for positive text - negative_prompt_embeds: Negative prompt embeddings - negative_text_ids: Position IDs for negative text - timesteps: Diffusion timesteps - do_true_cfg: Whether to apply CFG - guidance_scale: CFG scale factor - image_seq_len: Sequence length of image latents for slicing - cfg_normalize: Whether to normalize CFG output (default: False) - - Returns: - Denoised latents - """ - guidance = None - - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - self._current_timestep = t - - latent_model_input = latents - if image_latents is not None: - latent_model_input = torch.cat([latents, image_latents], dim=1) - - timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) - - positive_kwargs = { - "hidden_states": latent_model_input, - "timestep": timestep / 1000, - "guidance": guidance, - "encoder_hidden_states": prompt_embeds, - "txt_ids": text_ids, - "img_ids": latent_image_ids, - "return_dict": False, - } - negative_kwargs = { - "hidden_states": latent_model_input, - "timestep": timestep / 1000, - "guidance": guidance, - "encoder_hidden_states": negative_prompt_embeds, - "txt_ids": negative_text_ids, - "img_ids": latent_image_ids, - "return_dict": False, - } - - # For editing pipelines, we need to slice the output to remove condition latents - output_slice = image_seq_len if image_latents is not None else None - - # Predict noise with automatic CFG parallel handling - noise_pred = self.predict_noise_maybe_with_cfg( - do_true_cfg, - 
guidance_scale, - positive_kwargs, - negative_kwargs, - cfg_normalize, - output_slice, - ) - - # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync - latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) - - return latents - - def scheduler_step(self, noise_pred, t, latents): - """ - Step the scheduler. - """ - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) - return latents - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ @@ -736,26 +637,53 @@ def forward( ) self._num_timesteps = len(timesteps) + guidance = None + if image is not None: latent_image_ids = torch.cat([latents_ids, image_latents_ids], dim=0) else: latent_image_ids = latents_ids - # Denoising loop using diffuse method - latents = self.diffuse( - latents=latents, - image_latents=image_latents, - latent_image_ids=latent_image_ids, - prompt_embeds=prompt_embeds, - text_ids=text_ids, - negative_prompt_embeds=negative_prompt_embeds if guidance_scale > 1 else None, - negative_text_ids=negative_text_ids if guidance_scale > 1 else None, - timesteps=timesteps, - do_true_cfg=guidance_scale > 1, - guidance_scale=guidance_scale, - image_seq_len=image_seq_len, - cfg_normalize=False, - ) + for i, t in enumerate(timesteps): + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + + noise_pred_text = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred_text = 
noise_pred_text[:, :image_seq_len] + if guidance_scale > 1: + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred_uncond = noise_pred_uncond[:, :image_seq_len] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred = noise_pred_text + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) self._current_timestep = None From 901eb112cf5efd8af2d4ef24c41bf08249c93562 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 21 Jan 2026 17:07:23 +0800 Subject: [PATCH 27/67] update longcat edit pipeline Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../pipeline_longcat_image_edit.py | 78 +++++++++++-------- 1 file changed, 46 insertions(+), 32 deletions(-) diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py index f2c3fd648e..be26de538f 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py @@ -26,6 +26,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from 
vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput @@ -215,7 +216,7 @@ def split_quotation(prompt, quote_pairs=None): return result -class LongCatImageEditPipeline(nn.Module, SupportImageInput): +class LongCatImageEditPipeline(nn.Module, CFGParallelMixin, SupportImageInput): def __init__( self, *, @@ -652,38 +653,39 @@ def forward( latent_model_input = torch.cat([latents, image_latents], dim=1) timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) - - noise_pred_text = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - return_dict=False, - )[0] - noise_pred_text = noise_pred_text[:, :image_seq_len] - if guidance_scale > 1: - noise_pred_uncond = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - encoder_hidden_states=negative_prompt_embeds, - txt_ids=negative_text_ids, - img_ids=latent_image_ids, - return_dict=False, - )[0] - noise_pred_uncond = noise_pred_uncond[:, :image_seq_len] - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + do_true_cfg = guidance_scale > 1 + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states": prompt_embeds, + "txt_ids": text_ids, + "img_ids": latent_image_ids, + "return_dict": False, + } + + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "encoder_hidden_states": negative_prompt_embeds, + "txt_ids": negative_text_ids, + "img_ids": latent_image_ids, + "return_dict": False, + } else: - noise_pred = noise_pred_text - # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - 
if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) + negative_kwargs = None + + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + output_slice=image_seq_len, + ) + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) self._current_timestep = None @@ -699,6 +701,18 @@ def forward( image = self.vae.decode(latents, return_dict=False)[0] return DiffusionOutput(output=image) + def scheduler_step(self, noise_pred, t, latents): + """ + Step the scheduler. + """ + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + return latents + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights using AutoWeightsLoader for vLLM integration.""" loader = AutoWeightsLoader(self) From 13ff9da316eeeb5c4afc5c1146755915cc6e6050 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 21 Jan 2026 17:29:21 +0800 Subject: [PATCH 28/67] Zimage pipeline update Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../models/z_image/pipeline_z_image.py | 89 ++++++++++--------- 1 file changed, 46 insertions(+), 43 deletions(-) diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index 62d92900a2..f07ce9a0a1 100644 --- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ 
b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -19,6 +19,7 @@ import json import os from collections.abc import Callable, Iterable +from functools import partial from typing import Any import torch @@ -32,6 +33,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.z_image.z_image_transformer import ( @@ -142,7 +144,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class ZImagePipeline(nn.Module): +class ZImagePipeline(nn.Module, CFGParallelMixin): def __init__( self, *, @@ -527,6 +529,8 @@ def forward( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) + self.cfg_normalize_function = partial(self.cfg_normalize_function, actual_batch_size=actual_batch_size) + # 6. 
Denoising loop for i, t in enumerate(timesteps): if self.interrupt: @@ -550,57 +554,36 @@ def forward( # Run CFG only if configured AND scale is non-zero apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + latents = latents.to(self.transformer.dtype) + positive_kwargs = { + "x": latents, + "t": timestep, + "cap_feats": prompt_embeds, + } if apply_cfg: - latents_typed = latents.to(self.transformer.dtype) - latent_model_input = latents_typed.repeat(2, 1, 1, 1) - prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds - timestep_model_input = timestep.repeat(2) + negative_kwargs = { + "x": latents, + "t": timestep, + "cap_feats": negative_prompt_embeds, + } else: - latent_model_input = latents.to(self.transformer.dtype) - prompt_embeds_model_input = prompt_embeds - timestep_model_input = timestep - - latent_model_input = latent_model_input.unsqueeze(2) - latent_model_input_list = list(latent_model_input.unbind(dim=0)) - - model_out_list = self.transformer( - latent_model_input_list, - timestep_model_input, - prompt_embeds_model_input, - )[0] - - if apply_cfg: - # Perform CFG - pos_out = model_out_list[:actual_batch_size] - neg_out = model_out_list[actual_batch_size:] - - noise_pred = [] - for j in range(actual_batch_size): - pos = pos_out[j].float() - neg = neg_out[j].float() - - pred = pos + current_guidance_scale * (pos - neg) + negative_kwargs = None - # Renormalization - if self._cfg_normalization and float(self._cfg_normalization) > 0.0: - ori_pos_norm = torch.linalg.vector_norm(pos) - new_pos_norm = torch.linalg.vector_norm(pred) - max_new_norm = ori_pos_norm * float(self._cfg_normalization) - if new_pos_norm > max_new_norm: - pred = pred * (max_new_norm / new_pos_norm) - - noise_pred.append(pred) - - noise_pred = torch.stack(noise_pred, dim=0) - else: - noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + cfg_normalize = self._cfg_normalization and float(self._cfg_normalization) > 0.0 + noise_pred = 
self.predict_noise_maybe_with_cfg( + do_true_cfg=apply_cfg, + true_cfg_scale=current_guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=cfg_normalize, + ).float() noise_pred = noise_pred.squeeze(2) noise_pred = -noise_pred # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, apply_cfg) assert latents.dtype == torch.float32 if callback_on_step_end is not None: @@ -624,6 +607,26 @@ def forward( return DiffusionOutput(output=image) + def cfg_normalize_function(self, noise_pred, comb_pred, actual_batch_size): + assert noise_pred.shape[0] == actual_batch_size, ( + f"Expected noise_pred to have shape ({actual_batch_size}, *), got {noise_pred.shape}" + ) + + noise_pred, comb_pred = noise_pred.float(), comb_pred.float() + norm_pred = [] + for j in range(actual_batch_size): + pos = noise_pred[j] + pred = comb_pred[j] + # Renormalization + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + norm_pred.append(pred) + norm_pred = torch.stack(norm_pred, dim=0) + return norm_pred + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) From b9587cbf0511d939f310e91c538f85d9ee534169 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 21 Jan 2026 17:52:48 +0800 Subject: [PATCH 29/67] stable audio pipeline edits Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../stable_audio/pipeline_stable_audio.py | 78 ++++++++++++------- 1 file changed, 48 insertions(+), 30 deletions(-) diff --git 
a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py index c48d68efd6..b66299a7c7 100644 --- a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py +++ b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py @@ -25,6 +25,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportAudioOutput @@ -58,7 +59,7 @@ def post_process_func( return post_process_func -class StableAudioPipeline(nn.Module, SupportAudioOutput): +class StableAudioPipeline(nn.Module, SupportAudioOutput, CFGParallelMixin): """ Pipeline for text-to-audio generation using Stable Audio Open. 
@@ -319,10 +320,6 @@ def encode_duration( seconds_start_hidden_states = projection_output.seconds_start_hidden_states seconds_end_hidden_states = projection_output.seconds_end_hidden_states - if do_classifier_free_guidance: - seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states], dim=0) - seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states], dim=0) - return seconds_start_hidden_states, seconds_end_hidden_states def prepare_latents( @@ -464,11 +461,20 @@ def forward( batch_size, ) + if do_classifier_free_guidance: + # split prompt_embeds into positive and negative + negative_prompt_embeds, prompt_embeds = prompt_embeds.chunk(2) + # Create combined embeddings text_audio_duration_embeds = torch.cat( [prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1, ) + if do_classifier_free_guidance and negative_prompt_embeds is not None: + negative_text_audio_duration_embeds = torch.cat( + [negative_prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], + dim=1, + ) audio_duration_embeds = torch.cat( [seconds_start_hidden_states, seconds_end_hidden_states], dim=2, @@ -477,14 +483,6 @@ def forward( # Handle CFG without negative prompt if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None: negative_text_audio_duration_embeds = torch.zeros_like(text_audio_duration_embeds) - text_audio_duration_embeds = torch.cat( - [negative_text_audio_duration_embeds, text_audio_duration_embeds], - dim=0, - ) - audio_duration_embeds = torch.cat( - [audio_duration_embeds, audio_duration_embeds], - dim=0, - ) # Duplicate for multiple waveforms per prompt bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape @@ -532,27 +530,39 @@ def forward( for t in timesteps: self._current_timestep = t - # Expand latents for CFG - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - 
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # Predict noise - noise_pred = self.transformer( - latent_model_input, - t.unsqueeze(0), - encoder_hidden_states=text_audio_duration_embeds, - global_hidden_states=audio_duration_embeds, - rotary_embedding=rotary_embedding, - return_dict=False, - )[0] + latent_model_input = self.scheduler.scale_model_input(latents, t) - # Perform CFG + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": t.unsqueeze(0), + "encoder_hidden_states": text_audio_duration_embeds, + "global_hidden_states": audio_duration_embeds, + "rotary_embedding": rotary_embedding, + "return_dict": False, + } if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": t.unsqueeze(0), + "encoder_hidden_states": negative_text_audio_duration_embeds, + "global_hidden_states": audio_duration_embeds, + "rotary_embedding": rotary_embedding, + "return_dict": False, + } + else: + negative_kwargs = None + + # Predict noise + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_classifier_free_guidance, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + ) # Scheduler step - latents = self.scheduler.step(noise_pred, t, latents).prev_sample + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_classifier_free_guidance) self._current_timestep = None @@ -569,6 +579,14 @@ def forward( return DiffusionOutput(output=audio) + def scheduler_step(self, noise_pred, t, latents): + """ + Step the scheduler. 
+ """ + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + return latents + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights using AutoWeightsLoader for vLLM integration.""" loader = AutoWeightsLoader(self) From c2ae71d4bbe1e26a1ab704bf56ecd318c59c6ce8 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 21 Jan 2026 20:14:55 +0800 Subject: [PATCH 30/67] latents .contiguous() Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/distributed/cfg_parallel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_omni/diffusion/distributed/cfg_parallel.py b/vllm_omni/diffusion/distributed/cfg_parallel.py index 19c946e956..69f7018c8b 100644 --- a/vllm_omni/diffusion/distributed/cfg_parallel.py +++ b/vllm_omni/diffusion/distributed/cfg_parallel.py @@ -220,6 +220,7 @@ def scheduler_step_maybe_with_cfg(self, noise_pred, t, latents, do_true_cfg): latents = self.scheduler_step(noise_pred, t, latents) # Broadcast the updated latents to all ranks + latents = latents.contiguous() cfg_group.broadcast(latents, src=0) else: # No CFG parallel: directly compute scheduler step From d927698a340b4c143b11486dd3605f767feff43d Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 26 Jan 2026 20:10:48 +0800 Subject: [PATCH 31/67] t2v cfg_parallel Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../offline_inference/text_to_video/text_to_video.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index 63474987fa..d999c3defd 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -102,6 +102,13 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of GPUs 
used for ring sequence parallelism.", ) + parser.add_argument( + "--cfg_parallel_size", + type=int, + default=1, + choices=[1, 2], + help="Number of GPUs used for classifier free guidance parallel size.", + ) return parser.parse_args() @@ -137,6 +144,9 @@ def main(): # Check if profiling is requested via environment variable profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) + parallel_config = DiffusionParallelConfig( + cfg_parallel_size=args.cfg_parallel_size, + ) omni = Omni( model=args.model, enable_layerwise_offload=args.enable_layerwise_offload, From 611eb2a54d64ae6a5bd769d6ae25ab468ac1f857 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 26 Jan 2026 20:29:36 +0800 Subject: [PATCH 32/67] cache empty Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/distributed/cfg_parallel.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vllm_omni/diffusion/distributed/cfg_parallel.py b/vllm_omni/diffusion/distributed/cfg_parallel.py index 69f7018c8b..7023dc4d9f 100644 --- a/vllm_omni/diffusion/distributed/cfg_parallel.py +++ b/vllm_omni/diffusion/distributed/cfg_parallel.py @@ -66,12 +66,21 @@ def predict_noise_maybe_with_cfg( local_pred = local_pred[:, :output_slice] gathered = cfg_group.all_gather(local_pred, separate_tensors=True) + + del local_pred + if cfg_rank == 0: noise_pred = gathered[0] neg_noise_pred = gathered[1] noise_pred = self.combine_cfg_noise(noise_pred, neg_noise_pred, true_cfg_scale, cfg_normalize) + + del gathered, neg_noise_pred + torch.cuda.empty_cache() + return noise_pred else: + del gathered + torch.cuda.empty_cache() return None else: # Sequential CFG: compute both positive and negative From 51566e4d81055796e11abdba76a03363717fb160 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 27 Jan 2026 10:26:56 +0800 Subject: [PATCH 33/67] correct sd3 pipeline Signed-off-by: Didan Deng 
<33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/models/sd3/pipeline_sd3.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py index 49b44cb729..1a8cca66f1 100644 --- a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py +++ b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py @@ -18,14 +18,14 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device -from vllm_omni.diffusion.model_executor.model_loader.weight_utils import ( - download_weights_from_hf_specific, -) from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.sd3.sd3_transformer import ( SD3Transformer2DModel, ) from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, +) logger = logging.getLogger(__name__) From b3f0936d53db23436682156101866acb2154e3fd Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 27 Jan 2026 15:29:35 +0800 Subject: [PATCH 34/67] video script with new kwargs Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../image_to_video/image_to_video.py | 19 ++++++++++++++++++- .../text_to_video/text_to_video.py | 1 + 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py index 1785287849..646e002bfd 100644 --- a/examples/offline_inference/image_to_video/image_to_video.py +++ b/examples/offline_inference/image_to_video/image_to_video.py @@ -26,6 +26,7 @@ import PIL.Image import torch +from vllm_omni.diffusion.data import DiffusionParallelConfig from 
vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput @@ -85,6 +86,18 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of ready layers (blocks) to keep on GPU during generation.", ) + parser.add_argument( + "--cfg_parallel_size", + type=int, + default=1, + choices=[1, 2], + help="Number of GPUs used for classifier free guidance parallel size.", + ) + parser.add_argument( + "--enforce_eager", + action="store_true", + help="Disable torch.compile and force eager execution.", + ) return parser.parse_args() @@ -120,7 +133,9 @@ def main(): # Check if profiling is requested via environment variable profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) - + parallel_config = DiffusionParallelConfig( + cfg_parallel_size=args.cfg_parallel_size, + ) omni = Omni( model=args.model, enable_layerwise_offload=args.enable_layerwise_offload, @@ -130,6 +145,8 @@ def main(): boundary_ratio=args.boundary_ratio, flow_shift=args.flow_shift, enable_cpu_offload=args.enable_cpu_offload, + parallel_config=parallel_config, + enforce_eager=args.enforce_eager, ) if profiler_enabled: diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index d999c3defd..ed63d75885 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -109,6 +109,7 @@ def parse_args() -> argparse.Namespace: choices=[1, 2], help="Number of GPUs used for classifier free guidance parallel size.", ) + return parser.parse_args() From b20cfadf30b8c101dcbe5c8471fe7c07a983bf23 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 27 Jan 2026 15:50:09 +0800 Subject: [PATCH 35/67] revert sd audio pipeline change Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- 
.../stable_audio/pipeline_stable_audio.py | 78 +++++++------------ 1 file changed, 30 insertions(+), 48 deletions(-) diff --git a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py index b66299a7c7..c48d68efd6 100644 --- a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py +++ b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py @@ -25,7 +25,6 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportAudioOutput @@ -59,7 +58,7 @@ def post_process_func( return post_process_func -class StableAudioPipeline(nn.Module, SupportAudioOutput, CFGParallelMixin): +class StableAudioPipeline(nn.Module, SupportAudioOutput): """ Pipeline for text-to-audio generation using Stable Audio Open. 
@@ -320,6 +319,10 @@ def encode_duration( seconds_start_hidden_states = projection_output.seconds_start_hidden_states seconds_end_hidden_states = projection_output.seconds_end_hidden_states + if do_classifier_free_guidance: + seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states], dim=0) + seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states], dim=0) + return seconds_start_hidden_states, seconds_end_hidden_states def prepare_latents( @@ -461,20 +464,11 @@ def forward( batch_size, ) - if do_classifier_free_guidance: - # split prompt_embeds into positive and negative - negative_prompt_embeds, prompt_embeds = prompt_embeds.chunk(2) - # Create combined embeddings text_audio_duration_embeds = torch.cat( [prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1, ) - if do_classifier_free_guidance and negative_prompt_embeds is not None: - negative_text_audio_duration_embeds = torch.cat( - [negative_prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], - dim=1, - ) audio_duration_embeds = torch.cat( [seconds_start_hidden_states, seconds_end_hidden_states], dim=2, @@ -483,6 +477,14 @@ def forward( # Handle CFG without negative prompt if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None: negative_text_audio_duration_embeds = torch.zeros_like(text_audio_duration_embeds) + text_audio_duration_embeds = torch.cat( + [negative_text_audio_duration_embeds, text_audio_duration_embeds], + dim=0, + ) + audio_duration_embeds = torch.cat( + [audio_duration_embeds, audio_duration_embeds], + dim=0, + ) # Duplicate for multiple waveforms per prompt bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape @@ -530,39 +532,27 @@ def forward( for t in timesteps: self._current_timestep = t - latent_model_input = self.scheduler.scale_model_input(latents, t) - - positive_kwargs = { - "hidden_states": latent_model_input, - 
"timestep": t.unsqueeze(0), - "encoder_hidden_states": text_audio_duration_embeds, - "global_hidden_states": audio_duration_embeds, - "rotary_embedding": rotary_embedding, - "return_dict": False, - } - if do_classifier_free_guidance: - negative_kwargs = { - "hidden_states": latent_model_input, - "timestep": t.unsqueeze(0), - "encoder_hidden_states": negative_text_audio_duration_embeds, - "global_hidden_states": audio_duration_embeds, - "rotary_embedding": rotary_embedding, - "return_dict": False, - } - else: - negative_kwargs = None + # Expand latents for CFG + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # Predict noise - noise_pred = self.predict_noise_maybe_with_cfg( - do_true_cfg=do_classifier_free_guidance, - true_cfg_scale=guidance_scale, - positive_kwargs=positive_kwargs, - negative_kwargs=negative_kwargs, - cfg_normalize=False, - ) + noise_pred = self.transformer( + latent_model_input, + t.unsqueeze(0), + encoder_hidden_states=text_audio_duration_embeds, + global_hidden_states=audio_duration_embeds, + rotary_embedding=rotary_embedding, + return_dict=False, + )[0] + + # Perform CFG + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # Scheduler step - latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_classifier_free_guidance) + latents = self.scheduler.step(noise_pred, t, latents).prev_sample self._current_timestep = None @@ -579,14 +569,6 @@ def forward( return DiffusionOutput(output=audio) - def scheduler_step(self, noise_pred, t, latents): - """ - Step the scheduler. 
- """ - latents = self.scheduler.step(noise_pred, t, latents).prev_sample - - return latents - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights using AutoWeightsLoader for vLLM integration.""" loader = AutoWeightsLoader(self) From e3e1ba28db59a0a4888240a6506d6cfaa7693ac6 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 27 Jan 2026 16:11:33 +0800 Subject: [PATCH 36/67] revert zimage pipeline Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../models/z_image/pipeline_z_image.py | 89 +++++++++---------- 1 file changed, 43 insertions(+), 46 deletions(-) diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index f07ce9a0a1..62d92900a2 100644 --- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -19,7 +19,6 @@ import json import os from collections.abc import Callable, Iterable -from functools import partial from typing import Any import torch @@ -33,7 +32,6 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.z_image.z_image_transformer import ( @@ -144,7 +142,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class ZImagePipeline(nn.Module, CFGParallelMixin): +class ZImagePipeline(nn.Module): def __init__( self, *, @@ -529,8 +527,6 @@ def forward( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - self.cfg_normalize_function = partial(self.cfg_normalize_function, 
actual_batch_size=actual_batch_size) - # 6. Denoising loop for i, t in enumerate(timesteps): if self.interrupt: @@ -554,36 +550,57 @@ def forward( # Run CFG only if configured AND scale is non-zero apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 - latents = latents.to(self.transformer.dtype) - positive_kwargs = { - "x": latents, - "t": timestep, - "cap_feats": prompt_embeds, - } if apply_cfg: - negative_kwargs = { - "x": latents, - "t": timestep, - "cap_feats": negative_prompt_embeds, - } + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) else: - negative_kwargs = None + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) - cfg_normalize = self._cfg_normalization and float(self._cfg_normalization) > 0.0 - noise_pred = self.predict_noise_maybe_with_cfg( - do_true_cfg=apply_cfg, - true_cfg_scale=current_guidance_scale, - positive_kwargs=positive_kwargs, - negative_kwargs=negative_kwargs, - cfg_normalize=cfg_normalize, - ).float() + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * 
float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) noise_pred = noise_pred.squeeze(2) noise_pred = -noise_pred # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, apply_cfg) + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] assert latents.dtype == torch.float32 if callback_on_step_end is not None: @@ -607,26 +624,6 @@ def forward( return DiffusionOutput(output=image) - def cfg_normalize_function(self, noise_pred, comb_pred, actual_batch_size): - assert noise_pred.shape[0] == actual_batch_size, ( - f"Expected noise_pred to have shape ({actual_batch_size}, *), got {noise_pred.shape}" - ) - - noise_pred, comb_pred = noise_pred.float(), comb_pred.float() - norm_pred = [] - for j in range(actual_batch_size): - pos = noise_pred[j] - pred = comb_pred[j] - # Renormalization - ori_pos_norm = torch.linalg.vector_norm(pos) - new_pos_norm = torch.linalg.vector_norm(pred) - max_new_norm = ori_pos_norm * float(self._cfg_normalization) - if new_pos_norm > max_new_norm: - pred = pred * (max_new_norm / new_pos_norm) - norm_pred.append(pred) - norm_pred = torch.stack(norm_pred, dim=0) - return norm_pred - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) From 664d76daf239c3a43e351c10a5fa2cb6f0582870 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 27 Jan 2026 16:22:07 +0800 Subject: [PATCH 37/67] support list update Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- docs/user_guide/diffusion_acceleration.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git 
a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 8d31747d21..70f44ee8bd 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -39,23 +39,23 @@ The following table shows which models are currently supported by each accelerat | Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | CFG-Parallel | |-------|------------------|:----------:|:-----------:|:-----------:|:----------------:|:----------------:| -| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | ✅ | ❌ | ❌ | ❌ | -| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ❌ | ✅ | ❌ | ❌ | ❌ | -| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ✅ | ❌ | ❌ | ❌ | +| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | ✅ | ❌ | ❌ | ✅ | +| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ❌ | ✅ | ❌ | ❌ | ✅ | +| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ✅ | ❌ | ❌ | ✅ | | **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | ✅ | | **Qwen-Image-2512** | `Qwen/Qwen-Image-2512` | ✅ | ✅ | ✅ | ✅ | ✅ | | **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | ✅ | | **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | ✅ | | **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ❌ | ✅ | ✅ | ✅ | ✅ | | **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ❌ | ❌ | -| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ❌ | +| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ✅ | | **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ | ### VideoGen -| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | -|-------|------------------|:--------:|:---------:|:----------:|:--------------:| -| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ❌ | ✅ | ❌ | ❌ | +| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention |CFG-Parallel | 
+|-------|------------------|:--------:|:---------:|:----------:|:--------------:|:----------------:| +| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ❌ | ✅ | ❌ | ❌ | ✅ | ## Performance Benchmarks From 50e49b185e397c0a96c7fa7ace90d83429b3e840 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 27 Jan 2026 16:28:16 +0800 Subject: [PATCH 38/67] update how-to-parallelize-a-pipeline Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../diffusion/parallelism_acceleration.md | 167 ++++++++++++------ 1 file changed, 117 insertions(+), 50 deletions(-) diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index 1a9f540ad3..fbebe6f2c6 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -416,58 +416,125 @@ In `QwenImagePipeline`, each diffusion step runs two denoiser forward passes seq CFG-Parallel assigns these two branches to different ranks in the **CFG group** and synchronizes the results. -Below is an example of CFG-Parallel implementation: +vLLM-omni provides `CFGParallelMixin` base class that encapsulates the CFG parallel logic. By inheriting from this mixin and calling its methods, pipelines can easily implement CFG parallel without writing repetitive code. + +**Key Methods in CFGParallelMixin:** +- `predict_noise_maybe_with_cfg()`: Automatically handles CFG parallel noise prediction +- `scheduler_step_maybe_with_cfg()`: Scheduler step with automatic CFG rank synchronization + +**Example Implementation:** ```python -def diffuse( +class QwenImageCFGParallelMixin(CFGParallelMixin): + def diffuse( self, + prompt_embeds, + prompt_embeds_mask, + negative_prompt_embeds, + negative_prompt_embeds_mask, + latents, + img_shapes, + txt_seq_lens, + negative_txt_seq_lens, + timesteps, + do_true_cfg, + guidance, + true_cfg_scale, + cfg_normalize=True, ... 
- ): - # Enable CFG-parallel: rank0 computes positive, rank1 computes negative. - cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 - - self.transformer.do_true_cfg = do_true_cfg - - if cfg_parallel_ready: - cfg_group = get_cfg_group() - cfg_rank = get_classifier_free_guidance_rank() - - if cfg_rank == 0: - local_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - else: - local_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - - gathered = cfg_group.all_gather(local_pred, separate_tensors=True) - if cfg_rank == 0: - noise_pred = gathered[0] - neg_noise_pred = gathered[1] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - cfg_group.broadcast(latents, src=0) - else: - # fallback: run positive then negative sequentially on one rank - ... 
+ ): + self.transformer.do_true_cfg = do_true_cfg + + for i, t in enumerate(timesteps): + timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) + + # Prepare kwargs for positive (conditional) prediction + positive_kwargs = { + "hidden_states": latents, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states_mask": prompt_embeds_mask, + "encoder_hidden_states": prompt_embeds, + "img_shapes": img_shapes, + "txt_seq_lens": txt_seq_lens, + } + + # Prepare kwargs for negative (unconditional) prediction + if do_true_cfg: + negative_kwargs = { + "hidden_states": latents, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states_mask": negative_prompt_embeds_mask, + "encoder_hidden_states": negative_prompt_embeds, + "img_shapes": img_shapes, + "txt_seq_lens": negative_txt_seq_lens, + } + else: + negative_kwargs = None + + # Predict noise with automatic CFG parallel handling + # - In CFG parallel mode: rank0 computes positive, rank1 computes negative + # - Automatically gathers results and combines them on rank0 + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=true_cfg_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=cfg_normalize, + ) + + # Step scheduler with automatic CFG synchronization + # - Only rank0 computes the scheduler step + # - Automatically broadcasts updated latents to all ranks + latents = self.scheduler_step_maybe_with_cfg( + noise_pred, t, latents, do_true_cfg + ) + + return latents +``` + +**How it works:** +1. Prepare separate `positive_kwargs` and `negative_kwargs` for conditional and unconditional predictions +2. 
Call `predict_noise_maybe_with_cfg()` which: + - Detects if CFG parallel is enabled (`get_classifier_free_guidance_world_size() > 1`) + - Distributes computation: rank0 processes positive, rank1 processes negative + - Gathers predictions and combines them using `combine_cfg_noise()` on rank0 + - Returns combined noise prediction (only valid on rank0) +3. Call `scheduler_step_maybe_with_cfg()` which: + - Only rank0 computes the scheduler step + - Broadcasts the updated latents to all ranks for synchronization + +**How to customize** + +Some pipelines may need to customize the following functions in `CFGParallelMixin`: +1. You may need to edit `predict_noise` function for custom behaviors. +```python +def predict_noise(self, *args, **kwargs): + """ + Forward pass through transformer to predict noise. + + Subclasses should override this if they need custom behavior, + but the default implementation calls self.transformer. + """ + return self.transformer(*args, **kwargs)[0] + +``` +2. The default normalization function after combining the noise predictions from both branches is as follows. You may need to customize it. +```python +def cfg_normalize_function(self, noise_pred, comb_pred): + """ + Normalize the combined noise prediction. 
+ + Args: + noise_pred: positive noise prediction + comb_pred: combined noise prediction after CFG + + Returns: + Normalized noise prediction tensor + """ + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + return noise_pred ``` From e1ed60807875f86449ba2efca396d872f28bbf07 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 27 Jan 2026 16:38:38 +0800 Subject: [PATCH 39/67] support list update wan2.2 Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- docs/user_guide/diffusion_acceleration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 70f44ee8bd..d081243782 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -55,7 +55,7 @@ The following table shows which models are currently supported by each accelerat | Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention |CFG-Parallel | |-------|------------------|:--------:|:---------:|:----------:|:--------------:|:----------------:| -| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ❌ | ✅ | ❌ | ❌ | ✅ | +| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ❌ | ✅ | ✅ | ✅ | ✅ | ## Performance Benchmarks From fbf4837a51be67e964c4b7b9e45e6f33943bd026 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 27 Jan 2026 17:20:31 +0800 Subject: [PATCH 40/67] revise document head Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/distributed/cfg_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/distributed/cfg_parallel.py b/vllm_omni/diffusion/distributed/cfg_parallel.py index 7023dc4d9f..ba0cc6c087 100644 --- 
a/vllm_omni/diffusion/distributed/cfg_parallel.py +++ b/vllm_omni/diffusion/distributed/cfg_parallel.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Base pipeline class for Qwen Image models with shared CFG functionality. +Base pipeline class for Diffusion models with shared CFG functionality. """ from abc import ABCMeta From cb5f7703b5d717000778c272269b2fa7629b9105 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 27 Jan 2026 17:36:55 +0800 Subject: [PATCH 41/67] fix t2v args Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- examples/offline_inference/text_to_video/text_to_video.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index ed63d75885..8d14e93c2c 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -140,14 +140,12 @@ def main(): parallel_config = DiffusionParallelConfig( ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree, + cfg_parallel_size=args.cfg_parallel_size, ) # Check if profiling is requested via environment variable profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) - parallel_config = DiffusionParallelConfig( - cfg_parallel_size=args.cfg_parallel_size, - ) omni = Omni( model=args.model, enable_layerwise_offload=args.enable_layerwise_offload, From aa3c337eedfa805cbc511225c471976535a1e95a Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 28 Jan 2026 09:49:40 +0800 Subject: [PATCH 42/67] fix parameter annotation Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../diffusion/distributed/cfg_parallel.py | 67 ++++++++++--------- 1 file changed, 36 insertions(+), 31 deletions(-) diff --git 
a/vllm_omni/diffusion/distributed/cfg_parallel.py b/vllm_omni/diffusion/distributed/cfg_parallel.py index ba0cc6c087..f88a28a1c1 100644 --- a/vllm_omni/diffusion/distributed/cfg_parallel.py +++ b/vllm_omni/diffusion/distributed/cfg_parallel.py @@ -6,6 +6,7 @@ """ from abc import ABCMeta +from typing import Any import torch @@ -26,13 +27,13 @@ class CFGParallelMixin(metaclass=ABCMeta): def predict_noise_maybe_with_cfg( self, - do_true_cfg, - true_cfg_scale, - positive_kwargs, - negative_kwargs, - cfg_normalize=True, - output_slice=None, - ): + do_true_cfg: bool, + true_cfg_scale: float, + positive_kwargs: dict[str, Any], + negative_kwargs: dict[str, Any] | None, + cfg_normalize: bool = True, + output_slice: int | None = None, + ) -> torch.Tensor | None: """ Predict noise with optional classifier-free guidance. @@ -103,7 +104,7 @@ def predict_noise_maybe_with_cfg( pred = pred[:, :output_slice] return pred - def cfg_normalize_function(self, noise_pred, comb_pred): + def cfg_normalize_function(self, noise_pred: torch.Tensor, comb_pred: torch.Tensor) -> torch.Tensor: """ Normalize the combined noise prediction. @@ -119,7 +120,9 @@ def cfg_normalize_function(self, noise_pred, comb_pred): noise_pred = comb_pred * (cond_norm / noise_norm) return noise_pred - def combine_cfg_noise(self, noise_pred, neg_noise_pred, true_cfg_scale, cfg_normalize=False): + def combine_cfg_noise( + self, noise_pred: torch.Tensor, neg_noise_pred: torch.Tensor, true_cfg_scale: float, cfg_normalize: bool = False + ) -> torch.Tensor: """ Combine conditional and unconditional noise predictions with CFG. @@ -141,7 +144,7 @@ def combine_cfg_noise(self, noise_pred, neg_noise_pred, true_cfg_scale, cfg_norm return noise_pred - def predict_noise(self, *args, **kwargs): + def predict_noise(self, *args: Any, **kwargs: Any) -> torch.Tensor: """ Forward pass through transformer to predict noise. 
@@ -152,9 +155,9 @@ def predict_noise(self, *args, **kwargs): def diffuse( self, - *args, - **kwargs, - ): + *args: Any, + **kwargs: Any, + ) -> Any: """ Diffusion loop with optional classifier-free guidance. @@ -187,7 +190,7 @@ def diffuse(self, latents, timesteps, prompt_embeds, negative_embeds, ...): """ raise NotImplementedError("Subclasses must implement diffuse") - def scheduler_step(self, noise_pred, t, latents): + def scheduler_step(self, noise_pred: torch.Tensor, t: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: """ Step the scheduler. @@ -201,7 +204,9 @@ def scheduler_step(self, noise_pred, t, latents): """ return self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - def scheduler_step_maybe_with_cfg(self, noise_pred, t, latents, do_true_cfg): + def scheduler_step_maybe_with_cfg( + self, noise_pred: torch.Tensor, t: torch.Tensor, latents: torch.Tensor, do_true_cfg: bool + ) -> torch.Tensor: """ Step the scheduler with (maybe) automatic CFG parallel synchronization. 
@@ -245,22 +250,22 @@ class QwenImageCFGParallelMixin(CFGParallelMixin): def diffuse( self, - prompt_embeds, - prompt_embeds_mask, - negative_prompt_embeds, - negative_prompt_embeds_mask, - latents, - img_shapes, - txt_seq_lens, - negative_txt_seq_lens, - timesteps, - do_true_cfg, - guidance, - true_cfg_scale, - image_latents=None, - cfg_normalize=True, - additional_transformer_kwargs=None, - ): + prompt_embeds: torch.Tensor, + prompt_embeds_mask: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + negative_prompt_embeds_mask: torch.Tensor, + latents: torch.Tensor, + img_shapes: torch.Tensor, + txt_seq_lens: torch.Tensor, + negative_txt_seq_lens: torch.Tensor, + timesteps: torch.Tensor, + do_true_cfg: bool, + guidance: torch.Tensor, + true_cfg_scale: float, + image_latents: torch.Tensor | None = None, + cfg_normalize: bool = True, + additional_transformer_kwargs: dict[str, Any] | None = None, + ) -> torch.Tensor: """ Diffusion loop with optional classifier-free guidance. From 444e52550920e59bd686aec71a08f08c4a58f787 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 28 Jan 2026 10:27:14 +0800 Subject: [PATCH 43/67] fix parameter annotation Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../models/ovis_image/pipeline_ovis_image.py | 22 +++++++++---------- .../diffusion/models/sd3/pipeline_sd3.py | 20 ++++++++--------- .../models/wan2_2/pipeline_wan2_2_ti2v.py | 4 ++-- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py index 3d7376a450..db87386804 100644 --- a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py +++ b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py @@ -432,17 +432,17 @@ def prepare_timesteps(self, num_inference_steps, sigmas, image_seq_len): def diffuse( self, - latents, - timesteps, - prompt_embeds, - 
negative_prompt_embeds, - text_ids, - negative_text_ids, - latent_image_ids, - do_true_cfg, - guidance_scale, - cfg_normalize=False, - ): + latents: torch.Tensor, + timesteps: torch.Tensor, + prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + text_ids: torch.Tensor, + negative_text_ids: torch.Tensor, + latent_image_ids: torch.Tensor, + do_true_cfg: bool, + guidance_scale: float, + cfg_normalize: bool = False, + ) -> torch.Tensor: """ Diffusion loop with optional classifier-free guidance. diff --git a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py index 1a8cca66f1..3668c132f5 100644 --- a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py +++ b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py @@ -497,16 +497,16 @@ def interrupt(self): def diffuse( self, - latents, - timesteps, - prompt_embeds, - pooled_prompt_embeds, - negative_prompt_embeds, - negative_pooled_prompt_embeds, - do_true_cfg, - guidance_scale, - cfg_normalize=False, - ): + latents: torch.Tensor, + timesteps: torch.Tensor, + prompt_embeds: torch.Tensor, + pooled_prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + negative_pooled_prompt_embeds: torch.Tensor, + do_true_cfg: bool, + guidance_scale: float, + cfg_normalize: bool = False, + ) -> torch.Tensor: """ Diffusion loop with optional classifier-free guidance. 
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index 426f2500d7..305507f453 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -19,7 +19,7 @@ import logging import os from collections.abc import Iterable -from typing import cast +from typing import cast, Any import numpy as np import PIL.Image @@ -458,7 +458,7 @@ def forward( return DiffusionOutput(output=output) - def predict_noise(self, current_model=None, **kwargs): + def predict_noise(self, current_model: nn.Module | None = None, **kwargs: Any) -> torch.Tensor: """ Forward pass through transformer to predict noise. From 63f64d307e531ebb2d7c252d90325fd1adf2529c Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 28 Jan 2026 10:32:56 +0800 Subject: [PATCH 44/67] empty_cache when cuda is available() Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/distributed/cfg_parallel.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/distributed/cfg_parallel.py b/vllm_omni/diffusion/distributed/cfg_parallel.py index f88a28a1c1..27898610f8 100644 --- a/vllm_omni/diffusion/distributed/cfg_parallel.py +++ b/vllm_omni/diffusion/distributed/cfg_parallel.py @@ -76,12 +76,14 @@ def predict_noise_maybe_with_cfg( noise_pred = self.combine_cfg_noise(noise_pred, neg_noise_pred, true_cfg_scale, cfg_normalize) del gathered, neg_noise_pred - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() return noise_pred else: del gathered - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() return None else: # Sequential CFG: compute both positive and negative From 1767766ca48f768e56e0417f682dda24245290af Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> 
Date: Wed, 28 Jan 2026 10:45:17 +0800 Subject: [PATCH 45/67] test unit Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../distributed/test_cfg_parallel.py | 431 ++++++++++++++++++ 1 file changed, 431 insertions(+) create mode 100644 tests/diffusion/distributed/test_cfg_parallel.py diff --git a/tests/diffusion/distributed/test_cfg_parallel.py b/tests/diffusion/distributed/test_cfg_parallel.py new file mode 100644 index 0000000000..0fc28023ee --- /dev/null +++ b/tests/diffusion/distributed/test_cfg_parallel.py @@ -0,0 +1,431 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for CFG (Classifier-Free Guidance) parallel functionality. + +This test verifies that predict_noise_maybe_with_cfg produces numerically +equivalent results with and without CFG parallel using fixed random inputs. +""" + +import os +from tkinter.constants import X + +import pytest +import torch + +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.parallel_state import ( + destroy_distributed_env, + get_classifier_free_guidance_rank, + get_classifier_free_guidance_world_size, + init_distributed_environment, + initialize_model_parallel, +) +from vllm_omni.utils.platform_utils import detect_device_type + +device_type = detect_device_type() +if device_type == "cuda": + torch_device = torch.cuda +elif device_type == "npu": + torch_device = torch.npu +else: + raise ValueError(f"Unsupported device type: {device_type} for this test script! Expected GPU or NPU.") + + +def update_environment_variables(envs_dict: dict[str, str]): + """Update multiple environment variables.""" + for k, v in envs_dict.items(): + os.environ[k] = v + + +class SimpleTransformer(torch.nn.Module): + """Simple transformer model for testing with random initialization. 
+ + Contains: + - Input projection (conv to hidden_dim) + - QKV projection layers + - Self-attention layer + - Output projection + """ + + def __init__(self, in_channels: int = 4, hidden_dim: int = 128, num_heads: int = 8): + super().__init__() + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads + + assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads" + + # Input projection: (B, C, H, W) -> (B, hidden_dim, H, W) + self.input_proj = torch.nn.Conv2d(in_channels, hidden_dim, 1) + + # QKV projection layers + self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim) + self.k_proj = torch.nn.Linear(hidden_dim, hidden_dim) + self.v_proj = torch.nn.Linear(hidden_dim, hidden_dim) + + # Output projection after attention + self.out_proj = torch.nn.Linear(hidden_dim, hidden_dim) + + # Final output projection: (B, hidden_dim, H, W) -> (B, C, H, W) + self.final_proj = torch.nn.Conv2d(hidden_dim, in_channels, 1) + + # Layer norm + self.norm1 = torch.nn.LayerNorm(hidden_dim) + self.norm2 = torch.nn.LayerNorm(hidden_dim) + + def forward(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor]: + """Forward pass with self-attention. 
+ + Args: + x: Input tensor of shape (B, C, H, W) + + Returns: + Output tensor of shape (B, C, H, W) + """ + B, C, H, W = x.shape + + # Input projection + x = self.input_proj(x) # (B, hidden_dim, H, W) + + # Reshape to sequence: (B, hidden_dim, H, W) -> (B, H*W, hidden_dim) + x = x.flatten(2).transpose(1, 2) # (B, H*W, hidden_dim) + + # Self-attention with residual connection + residual = x + x = self.norm1(x) + + # QKV projection + q = self.q_proj(x) # (B, H*W, hidden_dim) + k = self.k_proj(x) # (B, H*W, hidden_dim) + v = self.v_proj(x) # (B, H*W, hidden_dim) + + # Reshape for multi-head attention: (B, H*W, hidden_dim) -> (B, num_heads, H*W, head_dim) + seq_len = H * W + q = q.view(B, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(B, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(B, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + + # Scaled dot-product attention + scale = self.head_dim**-0.5 + attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale # (B, num_heads, H*W, H*W) + attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(1, 2).contiguous().view(B, seq_len, self.hidden_dim) + + attn_output = self.out_proj(attn_output) + + x = residual + attn_output + residual = x + x = self.norm2(x) + x = residual + X + x = x.transpose(1, 2).view(B, self.hidden_dim, H, W) + + out = self.final_proj(x) + + return (out,) + + +class TestCFGPipeline(CFGParallelMixin): + """Test pipeline using CFGParallelMixin.""" + + def __init__(self, in_channels: int = 4, hidden_dim: int = 128, seed: int = 42): + # Set seed BEFORE creating transformer to ensure consistent layer initialization + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + self.transformer = SimpleTransformer(in_channels, hidden_dim) + + # Re-initialize all parameters with fixed seed for full reproducibility + 
torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + for param in self.transformer.parameters(): + torch.nn.init.normal_(param, mean=0.0, std=0.02) + + +def _test_cfg_parallel_worker( + local_rank: int, + world_size: int, + cfg_parallel_size: int, + dtype: torch.dtype, + test_config: dict, +): + """Worker function for CFG parallel test.""" + device = torch.device(f"{device_type}:{local_rank}") + torch_device.set_device(device) + + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "29502", + } + ) + + init_distributed_environment() + initialize_model_parallel(cfg_parallel_size=cfg_parallel_size) + + cfg_rank = get_classifier_free_guidance_rank() + cfg_world_size = get_classifier_free_guidance_world_size() + + assert cfg_world_size == cfg_parallel_size + + # Create pipeline with same seed to ensure identical model weights across all ranks + # Note: model_seed is set inside TestCFGPipeline.__init__ + pipeline = TestCFGPipeline( + in_channels=test_config["channels"], + hidden_dim=test_config["hidden_dim"], + seed=test_config["model_seed"], + ) + pipeline.transformer = pipeline.transformer.to(device=device, dtype=dtype) + pipeline.transformer.eval() # Set to eval mode for deterministic behavior + + # Create fixed inputs with explicit seed setting for reproducibility + # Set both CPU and CUDA seeds to ensure identical inputs across all ranks + torch.manual_seed(test_config["input_seed"]) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(test_config["input_seed"]) + + batch_size = test_config["batch_size"] + channels = test_config["channels"] + height = test_config["height"] + width = test_config["width"] + + # Positive input + positive_input = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device) + + # Negative input with different seed + 
torch.manual_seed(test_config["input_seed"] + 1) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(test_config["input_seed"] + 1) + negative_input = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device) + + # Prepare kwargs for predict_noise_maybe_with_cfg + positive_kwargs = {"x": positive_input} + negative_kwargs = {"x": negative_input} + + with torch.no_grad(): + # Call predict_noise_maybe_with_cfg + noise_pred = pipeline.predict_noise_maybe_with_cfg( + do_true_cfg=True, + true_cfg_scale=test_config["cfg_scale"], + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=test_config["cfg_normalize"], + ) + + # Only rank 0 has valid output in CFG parallel mode + if cfg_rank == 0: + assert noise_pred is not None + result_path = test_config["result_path"] + torch.save(noise_pred.cpu(), result_path) + else: + assert noise_pred is None + + destroy_distributed_env() + + +def _test_cfg_sequential_worker( + local_rank: int, + world_size: int, + dtype: torch.dtype, + test_config: dict, +): + """Worker function for sequential CFG test (baseline).""" + device = torch.device(f"{device_type}:{local_rank}") + torch_device.set_device(device) + + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "29503", + } + ) + + init_distributed_environment() + initialize_model_parallel(cfg_parallel_size=1) # No CFG parallel + + cfg_world_size = get_classifier_free_guidance_world_size() + assert cfg_world_size == 1 + + # Create pipeline with same seed to ensure identical model weights as CFG parallel + # Note: model_seed is set inside TestCFGPipeline.__init__ + pipeline = TestCFGPipeline( + in_channels=test_config["channels"], + hidden_dim=test_config["hidden_dim"], + seed=test_config["model_seed"], + ) + pipeline.transformer = pipeline.transformer.to(device=device, dtype=dtype) + pipeline.transformer.eval() 
+ + # Create fixed inputs (same seed as CFG parallel to ensure identical inputs) + # Set both CPU and CUDA seeds for full reproducibility + torch.manual_seed(test_config["input_seed"]) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(test_config["input_seed"]) + + batch_size = test_config["batch_size"] + channels = test_config["channels"] + height = test_config["height"] + width = test_config["width"] + + # Positive input + positive_input = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device) + + # Negative input with different seed + torch.manual_seed(test_config["input_seed"] + 1) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(test_config["input_seed"] + 1) + negative_input = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device) + + positive_kwargs = {"x": positive_input} + negative_kwargs = {"x": negative_input} + + with torch.no_grad(): + noise_pred = pipeline.predict_noise_maybe_with_cfg( + do_true_cfg=True, + true_cfg_scale=test_config["cfg_scale"], + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=test_config["cfg_normalize"], + ) + + # Sequential CFG always returns output + assert noise_pred is not None + result_path = test_config["baseline_path"] + torch.save(noise_pred.cpu(), result_path) + + destroy_distributed_env() + + +@pytest.mark.parametrize("cfg_parallel_size", [2]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("cfg_normalize", [False, True]) +def test_predict_noise_maybe_with_cfg( + cfg_parallel_size: int, dtype: torch.dtype, batch_size: int, cfg_normalize: bool, tmp_path +): + """ + Test that predict_noise_maybe_with_cfg produces identical results + with and without CFG parallel. 
+ + Args: + cfg_parallel_size: Number of GPUs for CFG parallel + dtype: Data type for computation + batch_size: Batch size for testing + cfg_normalize: Whether to normalize CFG output + tmp_path: Temporary directory for storing results + """ + available_gpus = torch_device.device_count() + if available_gpus < cfg_parallel_size: + pytest.skip(f"Test requires {cfg_parallel_size} GPUs but only {available_gpus} available") + + test_config = { + "batch_size": batch_size, + "channels": 4, + "height": 16, + "width": 16, + "hidden_dim": 128, + "cfg_scale": 7.5, + "cfg_normalize": cfg_normalize, + "model_seed": 42, # Fixed seed for model initialization + "input_seed": 123, # Fixed seed for input generation + "baseline_path": str(tmp_path / "baseline.pt"), + "result_path": str(tmp_path / "cfg_parallel.pt"), + } + + # Run baseline (sequential CFG) on single GPU + torch.multiprocessing.spawn( + _test_cfg_sequential_worker, + args=(1, dtype, test_config), + nprocs=1, + ) + + # Run CFG parallel on multiple GPUs + torch.multiprocessing.spawn( + _test_cfg_parallel_worker, + args=(cfg_parallel_size, cfg_parallel_size, dtype, test_config), + nprocs=cfg_parallel_size, + ) + + # Load and compare results + baseline_output = torch.load(test_config["baseline_path"]) + cfg_parallel_output = torch.load(test_config["result_path"]) + + # Verify shapes match + assert baseline_output.shape == cfg_parallel_output.shape, ( + f"Shape mismatch: baseline {baseline_output.shape} vs CFG parallel {cfg_parallel_output.shape}" + ) + + # Verify numerical equivalence with appropriate tolerances + if dtype == torch.float32: + rtol, atol = 1e-5, 1e-5 + elif dtype == torch.bfloat16: + rtol, atol = 1e-2, 1e-2 + else: + rtol, atol = 1e-3, 1e-3 + + torch.testing.assert_close( + cfg_parallel_output, + baseline_output, + rtol=rtol, + atol=atol, + msg=( + f"CFG parallel output differs from sequential CFG\n" + f" dtype={dtype}, batch_size={batch_size}, cfg_normalize={cfg_normalize}\n" + f" Max diff: 
{(cfg_parallel_output - baseline_output).abs().max().item():.6e}" + ), + ) + + print( + f"✓ Test passed: cfg_size={cfg_parallel_size}, dtype={dtype}, " + f"batch_size={batch_size}, cfg_normalize={cfg_normalize}" + ) + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_predict_noise_without_cfg(dtype: torch.dtype): + """ + Test predict_noise_maybe_with_cfg when do_true_cfg=False. + + When CFG is disabled, only the positive branch should be computed. + This test runs on a single GPU without distributed environment. + """ + available_gpus = torch_device.device_count() + if available_gpus < 1: + pytest.skip("Test requires at least 1 GPU") + + device = torch.device(f"{device_type}:0") + torch_device.set_device(device) + + # Create pipeline without distributed environment + pipeline = TestCFGPipeline(in_channels=4, hidden_dim=128, seed=42) + pipeline.transformer = pipeline.transformer.to(device=device, dtype=dtype) + pipeline.transformer.eval() + + # Set seed for input generation + torch.manual_seed(123) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(123) + positive_input = torch.randn(1, 4, 16, 16, dtype=dtype, device=device) + + with torch.no_grad(): + noise_pred = pipeline.predict_noise_maybe_with_cfg( + do_true_cfg=False, # No CFG + true_cfg_scale=7.5, + positive_kwargs={"x": positive_input}, + negative_kwargs=None, + cfg_normalize=False, + ) + + # Should always return output when do_true_cfg=False + assert noise_pred is not None + assert noise_pred.shape == (1, 4, 16, 16) + + print(f"✓ Test passed: predict_noise without CFG (dtype={dtype})") From 9b948862ad94261d05e0c48979eff8aa8fa6f8f7 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 28 Jan 2026 10:57:38 +0800 Subject: [PATCH 46/67] fix parameter annotation Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py | 5 +++-- 
vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 572865c451..eaa59f0bde 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -7,7 +7,8 @@ import logging import os from collections.abc import Iterable -from typing import cast +from re import A +from typing import cast, Any import PIL.Image import torch @@ -644,7 +645,7 @@ def forward( return DiffusionOutput(output=output) - def predict_noise(self, current_model=None, **kwargs): + def predict_noise(self, current_model: nn.Module | None = None, **kwargs: Any) -> torch.Tensor: """ Forward pass through transformer to predict noise. diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 37f095ee43..626fa7acf8 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -6,7 +6,7 @@ import logging import os from collections.abc import Iterable -from typing import cast +from typing import cast, Any import numpy as np import PIL.Image @@ -545,7 +545,7 @@ def forward( return DiffusionOutput(output=output) - def predict_noise(self, current_model=None, **kwargs): + def predict_noise(self, current_model: nn.Module | None = None, **kwargs: Any) -> torch.Tensor: """ Forward pass through transformer to predict noise. 
From 8662f8e0477772b2720aa5d62705a7a7e309a000 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 28 Jan 2026 12:43:41 +0800 Subject: [PATCH 47/67] update unit test Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- tests/diffusion/distributed/test_cfg_parallel.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/diffusion/distributed/test_cfg_parallel.py b/tests/diffusion/distributed/test_cfg_parallel.py index 0fc28023ee..587b513c91 100644 --- a/tests/diffusion/distributed/test_cfg_parallel.py +++ b/tests/diffusion/distributed/test_cfg_parallel.py @@ -7,7 +7,6 @@ """ import os -from tkinter.constants import X import pytest import torch @@ -117,7 +116,7 @@ def forward(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor]: x = residual + attn_output residual = x x = self.norm2(x) - x = residual + X + x = residual + x x = x.transpose(1, 2).view(B, self.hidden_dim, H, W) out = self.final_proj(x) @@ -305,8 +304,8 @@ def _test_cfg_sequential_worker( @pytest.mark.parametrize("cfg_parallel_size", [2]) -@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) -@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("cfg_normalize", [False, True]) def test_predict_noise_maybe_with_cfg( cfg_parallel_size: int, dtype: torch.dtype, batch_size: int, cfg_normalize: bool, tmp_path @@ -389,7 +388,7 @@ def test_predict_noise_maybe_with_cfg( ) -@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) def test_predict_noise_without_cfg(dtype: torch.dtype): """ Test predict_noise_maybe_with_cfg when do_true_cfg=False. 
From 0b0d71d3e744759d16d850e21181614f3bf21c0c Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 28 Jan 2026 17:42:55 +0800 Subject: [PATCH 48/67] fix pre-commit error Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py | 3 +-- vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py | 2 +- vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index eaa59f0bde..8979a47f88 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -7,8 +7,7 @@ import logging import os from collections.abc import Iterable -from re import A -from typing import cast, Any +from typing import Any, cast import PIL.Image import torch diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 626fa7acf8..1e166d3482 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -6,7 +6,7 @@ import logging import os from collections.abc import Iterable -from typing import cast, Any +from typing import Any, cast import numpy as np import PIL.Image diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index 305507f453..1a2b94f8a3 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -19,7 +19,7 @@ import logging import os from collections.abc import Iterable -from typing import cast, Any +from typing import Any, cast import numpy as np import PIL.Image From cfbd49dcc3c91b69fe7bbf88d46a0fd7cc15e289 Mon Sep 17 00:00:00 2001 From: 
Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:20:35 +0800 Subject: [PATCH 49/67] fix pre-commit error Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- tests/diffusion/distributed/test_cfg_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/diffusion/distributed/test_cfg_parallel.py b/tests/diffusion/distributed/test_cfg_parallel.py index 587b513c91..2096c48d93 100644 --- a/tests/diffusion/distributed/test_cfg_parallel.py +++ b/tests/diffusion/distributed/test_cfg_parallel.py @@ -10,6 +10,7 @@ import pytest import torch +from vllm_omni.utils.platform_utils import detect_device_type from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.parallel_state import ( @@ -19,7 +20,6 @@ init_distributed_environment, initialize_model_parallel, ) -from vllm_omni.utils.platform_utils import detect_device_type device_type = detect_device_type() if device_type == "cuda": From 934569456584907028fe123ea6822d73714ae422 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:21:58 +0800 Subject: [PATCH 50/67] check cfg_parallel size in data.py Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 4248071f01..f884fb6f17 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -55,6 +55,7 @@ def _validate_parallel_config(self) -> Self: assert self.ulysses_degree > 0, "Ulysses degree must be > 0" assert self.ring_degree > 0, "Ring degree must be > 0" assert self.cfg_parallel_size > 0, "CFG parallel size must be > 0" + assert self.cfg_parallel_size in [1, 2], f"CFG parallel size must be 1 or 2, but got {self.cfg_parallel_size}" assert self.sequence_parallel_size == self.ulysses_degree * self.ring_degree, ( 
"Sequence parallel size must be equal to the product of ulysses degree and ring degree," f" but got {self.sequence_parallel_size} != {self.ulysses_degree} * {self.ring_degree}" From c8bcf2ead570f2d062093b274f653f64516c5367 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:29:54 +0800 Subject: [PATCH 51/67] update cfg_parallel_size arg doc Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- examples/offline_inference/image_to_image/image_to_image.md | 1 + examples/offline_inference/image_to_video/README.md | 1 + examples/offline_inference/text_to_image/README.md | 1 + examples/offline_inference/text_to_video/text_to_video.md | 1 + 4 files changed, 4 insertions(+) diff --git a/examples/offline_inference/image_to_image/image_to_image.md b/examples/offline_inference/image_to_image/image_to_image.md index b9e25bf986..d0986d6ee7 100644 --- a/examples/offline_inference/image_to_image/image_to_image.md +++ b/examples/offline_inference/image_to_image/image_to_image.md @@ -49,6 +49,7 @@ Key arguments: - `--output`: path to save the generated PNG. - `--vae_use_slicing`: enable VAE slicing for memory optimization. - `--vae_use_tiling`: enable VAE tiling for memory optimization. +- `--cfg_parallel_size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel). - `--enable-cpu-offload`: enable CPU offloading for diffusion models. > ℹ️ If you encounter OOM errors, try using `--vae_use_slicing` and `--vae_use_tiling` to reduce memory usage. diff --git a/examples/offline_inference/image_to_video/README.md b/examples/offline_inference/image_to_video/README.md index 52bb389e73..a1355dab69 100644 --- a/examples/offline_inference/image_to_video/README.md +++ b/examples/offline_inference/image_to_video/README.md @@ -56,6 +56,7 @@ Key arguments: - `--output`: Path to save the generated video. 
- `--vae_use_slicing`: Enable VAE slicing for memory optimization. - `--vae_use_tiling`: Enable VAE tiling for memory optimization. +- `--cfg_parallel_size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel). - `--enable-cpu-offload`: enable CPU offloading for diffusion models. > ℹ️ If you encounter OOM errors, try using `--vae_use_slicing` and `--vae_use_tiling` to reduce memory usage. diff --git a/examples/offline_inference/text_to_image/README.md b/examples/offline_inference/text_to_image/README.md index ab28b115c1..9c57a621cf 100644 --- a/examples/offline_inference/text_to_image/README.md +++ b/examples/offline_inference/text_to_image/README.md @@ -98,6 +98,7 @@ Key arguments: - `--output`: path to save the generated PNG. - `--vae_use_slicing`: enable VAE slicing for memory optimization. - `--vae_use_tiling`: enable VAE tiling for memory optimization. +- `--cfg_parallel_size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel). - `--enable-cpu-offload`: enable CPU offloading for diffusion models. > ℹ️ If you encounter OOM errors, try using `--vae_use_slicing` and `--vae_use_tiling` to reduce memory usage. diff --git a/examples/offline_inference/text_to_video/text_to_video.md b/examples/offline_inference/text_to_video/text_to_video.md index b5d0b2adc2..04f1a2653b 100644 --- a/examples/offline_inference/text_to_video/text_to_video.md +++ b/examples/offline_inference/text_to_video/text_to_video.md @@ -31,6 +31,7 @@ Key arguments: - `--output`: path to save the generated video. - `--vae_use_slicing`: enable VAE slicing for memory optimization. - `--vae_use_tiling`: enable VAE tiling for memory optimization. +- `--cfg_parallel_size`: set it to 2 to enable CFG Parallel. 
See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel). - `--enable-cpu-offload`: enable CPU offloading for diffusion models. > ℹ️ If you encounter OOM errors, try using `--vae_use_slicing` and `--vae_use_tiling` to reduce memory usage. From 35dbdeb7db603f7f63bdca87c75fa50acb949d73 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:34:36 +0800 Subject: [PATCH 52/67] doc refinement Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- docs/user_guide/diffusion/parallelism_acceleration.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index fbebe6f2c6..0d17b7e0da 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -375,7 +375,7 @@ def forward(self, hidden_states, ...): ##### Offline Inference -CFG-Parallel is enabled through `DiffusionParallelConfig(cfg_parallel_size=...)`. The recommended configuration is `cfg_parallel_size=2` (one rank for the positive branch and one rank for the negative branch). +CFG-Parallel is enabled through `DiffusionParallelConfig(cfg_parallel_size=2)`, which runs one rank for the positive branch and one rank for the negative branch. An example of offline inference using CFG-Parallel (image-to-image) is shown below: @@ -403,7 +403,7 @@ outputs = omni.generate( Notes: -- CFG-Parallel is only effective when **true CFG** is enabled (i.e., `true_cfg_scale > 1` and a `negative_prompt` is provided). +- CFG-Parallel is only effective when a `negative_prompt` is provided AND a guidance scale (or `cfg_scale`) is greater than 1. 
#### How to parallelize a pipeline From f3a54fe1a0ac5b48a441ce34dc11e5316187b42a Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:39:27 +0800 Subject: [PATCH 53/67] update doc with new arg Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- docs/user_guide/examples/offline_inference/image_to_image.md | 1 + docs/user_guide/examples/offline_inference/image_to_video.md | 1 + docs/user_guide/examples/offline_inference/text_to_image.md | 1 + docs/user_guide/examples/offline_inference/text_to_video.md | 1 + 4 files changed, 4 insertions(+) diff --git a/docs/user_guide/examples/offline_inference/image_to_image.md b/docs/user_guide/examples/offline_inference/image_to_image.md index 78c7f11fcb..c970106d2b 100644 --- a/docs/user_guide/examples/offline_inference/image_to_image.md +++ b/docs/user_guide/examples/offline_inference/image_to_image.md @@ -47,6 +47,7 @@ Key arguments: - `--image`: path(s) to the source image(s) (PNG/JPG, converted to RGB). Can specify multiple images. - `--prompt` / `--negative_prompt`: text description (string). - `--cfg_scale`: true classifier-free guidance scale (default: 4.0). Classifier-free guidance is enabled by setting cfg_scale > 1 and providing a negative_prompt. Higher guidance scale encourages images closely linked to the text prompt, usually at the expense of lower image quality. +- `--cfg_parallel_size`: the number of devices to run CFG Parallel. CFG Parallel is valid only if classifier-free guidance is enabled and `cfg_parallel_size` is set to 2. - `--guidance_scale`: guidance scale for guidance-distilled models (default: 1.0, disabled). Unlike classifier-free guidance (--cfg_scale), guidance-distilled models take the guidance scale directly as an input parameter. Enabled when guidance_scale > 1. Ignored when not using guidance-distilled models. - `--num_inference_steps`: diffusion sampling steps (more steps = higher quality, slower). 
- `--output`: path to save the generated PNG. diff --git a/docs/user_guide/examples/offline_inference/image_to_video.md b/docs/user_guide/examples/offline_inference/image_to_video.md index 32c655ffd7..d65839dd75 100644 --- a/docs/user_guide/examples/offline_inference/image_to_video.md +++ b/docs/user_guide/examples/offline_inference/image_to_video.md @@ -52,6 +52,7 @@ Key arguments: - `--num_frames`: Number of frames (default 81). - `--guidance_scale` and `--guidance_scale_high`: CFG scale (applied to low/high-noise stages for MoE). - `--negative_prompt`: Optional list of artifacts to suppress. +- `--cfg_parallel_size`: the number of devices to run CFG Parallel. CFG Parallel is valid only if classifier-free guidance is enabled and `cfg_parallel_size` is set to 2. - `--boundary_ratio`: Boundary split ratio for two-stage MoE models. - `--flow_shift`: Scheduler flow shift (5.0 for 720p, 12.0 for 480p). - `--num_inference_steps`: Number of denoising steps (default 50). diff --git a/docs/user_guide/examples/offline_inference/text_to_image.md b/docs/user_guide/examples/offline_inference/text_to_image.md index 2dd0f7b7cc..486b9f63b0 100644 --- a/docs/user_guide/examples/offline_inference/text_to_image.md +++ b/docs/user_guide/examples/offline_inference/text_to_image.md @@ -95,6 +95,7 @@ Key arguments: - `--prompt`: text description (string). - `--seed`: integer seed for deterministic sampling. - `--cfg_scale`: true CFG scale (model-specific guidance strength). +- `--cfg_parallel_size`: the number of devices to run CFG Parallel. CFG Parallel is valid only if classifier-free guidance is enabled and `cfg_parallel_size` is set to 2. - `--num_images_per_prompt`: number of images to generate per prompt (saves as `output`, `output_1`, ...). - `--num_inference_steps`: diffusion sampling steps (more steps = higher quality, slower). - `--height/--width`: output resolution (defaults 1024x1024). 
diff --git a/docs/user_guide/examples/offline_inference/text_to_video.md b/docs/user_guide/examples/offline_inference/text_to_video.md index 0cc22edea5..db0860b38e 100644 --- a/docs/user_guide/examples/offline_inference/text_to_video.md +++ b/docs/user_guide/examples/offline_inference/text_to_video.md @@ -28,6 +28,7 @@ Key arguments: - `--num_frames`: Number of frames (Wan default is 81). - `--guidance_scale` and `--guidance_scale_high`: CFG scale (applied to low/high).. - `--negative_prompt`: optional list of artifacts to suppress (the PR demo used a long Chinese string). +- `--cfg_parallel_size`: the number of devices to run CFG Parallel. CFG Parallel is valid only if classifier-free guidance is enabled and `cfg_parallel_size` is set to 2. - `--boundary_ratio`: Boundary split ratio for low/high DiT. - `--fps`: frames per second for the saved MP4 (requires `diffusers` export_to_video). - `--output`: path to save the generated video. From 8dd8e61fae973e2f4dc6610cbd4614dc1352fefd Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:53:13 +0800 Subject: [PATCH 54/67] offline script example in doc Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../diffusion/parallelism_acceleration.md | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index 0d17b7e0da..f3b2436007 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -383,19 +383,21 @@ An example of offline inference using CFG-Parallel (image-to-image) is shown bel from vllm_omni import Omni from vllm_omni.diffusion.data import DiffusionParallelConfig +image_path = "path_to_image.png" omni = Omni( model="Qwen/Qwen-Image-Edit", parallel_config=DiffusionParallelConfig(cfg_parallel_size=2), ) +input_image = 
Image.open(image_path).convert("RGB") outputs = omni.generate( { "prompt": "turn this cat to a dog", "negative_prompt": "low quality, blurry", + "multi_modal_data": {"image": input_image}, }, OmniDiffusionSamplingParams( true_cfg_scale=4.0, - pil_image=input_image, num_inference_steps=50, ), ) @@ -405,6 +407,19 @@ Notes: - CFG-Parallel is only effective when a `negative_prompt` is provided AND a guidance scale (or `cfg_scale`) is greater than 1. +See `examples/offline_inference/image_to_image/image_edit.py` for a complete working example. +```bash +cd examples/offline_inference/image_to_image/ +python image_edit.py \ + --model "Qwen/Qwen-Image-Edit" \ + --image "qwen_image_output.png" \ + --prompt "turn this cat to a dog" \ + --negative_prompt "low quality, blurry" \ + --cfg_scale 4.0 \ + --output "edited_image.png" \ + --cfg_parallel_size 2 +``` + #### How to parallelize a pipeline This section describes how to add CFG-Parallel to a diffusion **pipeline**. We use the Qwen-Image pipeline (`vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py`) as the reference implementation. 
From 18ce884b7cfe5fca98c56da5076c605fae7c5836 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:55:47 +0800 Subject: [PATCH 55/67] online serving args Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/entrypoints/cli/serve.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index 5613bdaeb5..b4d7772aaf 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -162,6 +162,15 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu "Equivalent to setting DiffusionParallelConfig.ring_degree.", ) + # CFG Parallel parameters + omni_config_group.add_argument( + "--cfg-parallel-size", + type=int, + default=1, + choices=[1, 2], + help="Number of devices for CFG parallel computation", + ) + # Cache optimization parameters omni_config_group.add_argument( "--cache-backend", From f55beb09534d02962725bf1fb619ece87bfc5ab7 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 29 Jan 2026 16:52:53 +0800 Subject: [PATCH 56/67] serve args Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/entrypoints/cli/serve.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index b4d7772aaf..083114ee40 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -162,15 +162,6 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu "Equivalent to setting DiffusionParallelConfig.ring_degree.", ) - # CFG Parallel parameters - omni_config_group.add_argument( - "--cfg-parallel-size", - type=int, - default=1, - choices=[1, 2], - help="Number of devices for CFG parallel computation", - ) - # Cache optimization parameters 
omni_config_group.add_argument( "--cache-backend", @@ -234,7 +225,12 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu help="Scheduler flow_shift for video models (e.g., 5.0 for 720p, 12.0 for 480p).", ) omni_config_group.add_argument( - "--cfg-parallel-size", type=int, default=1, help="Number of GPUs for CFG parallel computation" + "--cfg-parallel-size", + type=int, + default=1, + choices=[1, 2], + help="Number of devices for CFG parallel computation for diffusion models. " + "Equivalent to setting DiffusionParallelConfig.cfg_parallel_size.", ) return serve_parser From 3162aac01f2b5e775ead67d7b3a451221eb89aae Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 29 Jan 2026 16:54:38 +0800 Subject: [PATCH 57/67] update doc Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- docs/user_guide/diffusion/parallelism_acceleration.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index f3b2436007..748b679c69 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -373,7 +373,7 @@ def forward(self, hidden_states, ...): ### CFG-Parallel -##### Offline Inference +#### Offline Inference CFG-Parallel is enabled through `DiffusionParallelConfig(cfg_parallel_size=2)`, which runs one rank for the positive branch and one rank for the negative branch. @@ -420,6 +420,14 @@ python image_edit.py \ --cfg_parallel_size 2 ``` +#### Online Serving + +You can enable CFG-Parallel in online serving for diffusion models via `--cfg-parallel-size`: + +```bash +vllm serve Qwen/Qwen-Image-Edit --omni --port 8091 --cfg-parallel-size 2 +``` + #### How to parallelize a pipeline This section describes how to add CFG-Parallel to a diffusion **pipeline**. 
We use the Qwen-Image pipeline (`vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py`) as the reference implementation. From da2b3072c77d06540df6077157d331337485638f Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 29 Jan 2026 17:29:27 +0800 Subject: [PATCH 58/67] fix error Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../diffusion/models/qwen_image/pipeline_qwen_image_layered.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index d712d83a34..c914b5fd83 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -189,6 +189,7 @@ def retrieve_latents( class QwenImageLayeredPipeline(nn.Module, SupportImageInput, QwenImageCFGParallelMixin): color_format = "RGBA" + def __init__( self, *, From 5189be6fec1f0b985255dfc767357aaf76bbce60 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 30 Jan 2026 09:46:06 +0800 Subject: [PATCH 59/67] remove no_grad Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index 3ffd292ca3..3a5be60070 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -645,7 +645,6 @@ def current_timestep(self): def interrupt(self): return self._interrupt - @torch.no_grad() def forward( self, req: OmniDiffusionRequest, From c91a3a03c7239d0120bed0927a41e5c82cd0dd68 Mon Sep 17 00:00:00 2001 From: Didan Deng 
<33117903+wtomin@users.noreply.github.com> Date: Fri, 30 Jan 2026 09:50:15 +0800 Subject: [PATCH 60/67] remove torch.save & torch.load Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../distributed/test_cfg_parallel.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/tests/diffusion/distributed/test_cfg_parallel.py b/tests/diffusion/distributed/test_cfg_parallel.py index 2096c48d93..90003f52fe 100644 --- a/tests/diffusion/distributed/test_cfg_parallel.py +++ b/tests/diffusion/distributed/test_cfg_parallel.py @@ -149,6 +149,7 @@ def _test_cfg_parallel_worker( cfg_parallel_size: int, dtype: torch.dtype, test_config: dict, + result_queue: torch.multiprocessing.Queue, ): """Worker function for CFG parallel test.""" device = torch.device(f"{device_type}:{local_rank}") @@ -219,8 +220,7 @@ def _test_cfg_parallel_worker( # Only rank 0 has valid output in CFG parallel mode if cfg_rank == 0: assert noise_pred is not None - result_path = test_config["result_path"] - torch.save(noise_pred.cpu(), result_path) + result_queue.put(noise_pred.cpu()) else: assert noise_pred is None @@ -232,6 +232,7 @@ def _test_cfg_sequential_worker( world_size: int, dtype: torch.dtype, test_config: dict, + result_queue: torch.multiprocessing.Queue, ): """Worker function for sequential CFG test (baseline).""" device = torch.device(f"{device_type}:{local_rank}") @@ -297,8 +298,7 @@ def _test_cfg_sequential_worker( # Sequential CFG always returns output assert noise_pred is not None - result_path = test_config["baseline_path"] - torch.save(noise_pred.cpu(), result_path) + result_queue.put(noise_pred.cpu()) destroy_distributed_env() @@ -307,9 +307,7 @@ def _test_cfg_sequential_worker( @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("cfg_normalize", [False, True]) -def test_predict_noise_maybe_with_cfg( - cfg_parallel_size: int, dtype: torch.dtype, batch_size: int, 
cfg_normalize: bool, tmp_path -): +def test_predict_noise_maybe_with_cfg(cfg_parallel_size: int, dtype: torch.dtype, batch_size: int, cfg_normalize: bool): """ Test that predict_noise_maybe_with_cfg produces identical results with and without CFG parallel. @@ -319,7 +317,6 @@ def test_predict_noise_maybe_with_cfg( dtype: Data type for computation batch_size: Batch size for testing cfg_normalize: Whether to normalize CFG output - tmp_path: Temporary directory for storing results """ available_gpus = torch_device.device_count() if available_gpus < cfg_parallel_size: @@ -335,27 +332,29 @@ def test_predict_noise_maybe_with_cfg( "cfg_normalize": cfg_normalize, "model_seed": 42, # Fixed seed for model initialization "input_seed": 123, # Fixed seed for input generation - "baseline_path": str(tmp_path / "baseline.pt"), - "result_path": str(tmp_path / "cfg_parallel.pt"), } + # Create queues for receiving results + baseline_queue = torch.multiprocessing.Queue() + cfg_parallel_queue = torch.multiprocessing.Queue() + # Run baseline (sequential CFG) on single GPU torch.multiprocessing.spawn( _test_cfg_sequential_worker, - args=(1, dtype, test_config), + args=(1, dtype, test_config, baseline_queue), nprocs=1, ) # Run CFG parallel on multiple GPUs torch.multiprocessing.spawn( _test_cfg_parallel_worker, - args=(cfg_parallel_size, cfg_parallel_size, dtype, test_config), + args=(cfg_parallel_size, cfg_parallel_size, dtype, test_config, cfg_parallel_queue), nprocs=cfg_parallel_size, ) - # Load and compare results - baseline_output = torch.load(test_config["baseline_path"]) - cfg_parallel_output = torch.load(test_config["result_path"]) + # Get results from queues + baseline_output = baseline_queue.get() + cfg_parallel_output = cfg_parallel_queue.get() # Verify shapes match assert baseline_output.shape == cfg_parallel_output.shape, ( From 1fdde86d9931f7619d4309f27fb1b6239776a965 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 30 Jan 2026 
10:01:23 +0800 Subject: [PATCH 61/67] update hardware devices Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../distributed/test_cfg_parallel.py | 26 +++++++------------ 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/tests/diffusion/distributed/test_cfg_parallel.py b/tests/diffusion/distributed/test_cfg_parallel.py index 90003f52fe..7b997f112d 100644 --- a/tests/diffusion/distributed/test_cfg_parallel.py +++ b/tests/diffusion/distributed/test_cfg_parallel.py @@ -10,7 +10,6 @@ import pytest import torch -from vllm_omni.utils.platform_utils import detect_device_type from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.parallel_state import ( @@ -20,14 +19,7 @@ init_distributed_environment, initialize_model_parallel, ) - -device_type = detect_device_type() -if device_type == "cuda": - torch_device = torch.cuda -elif device_type == "npu": - torch_device = torch.npu -else: - raise ValueError(f"Unsupported device type: {device_type} for this test script! 
Expected GPU or NPU.") +from vllm_omni.platforms import current_omni_platform def update_environment_variables(envs_dict: dict[str, str]): @@ -152,8 +144,8 @@ def _test_cfg_parallel_worker( result_queue: torch.multiprocessing.Queue, ): """Worker function for CFG parallel test.""" - device = torch.device(f"{device_type}:{local_rank}") - torch_device.set_device(device) + device = torch.device(f"{current_omni_platform.device_type}:{local_rank}") + current_omni_platform.set_device(device) update_environment_variables( { @@ -235,8 +227,8 @@ def _test_cfg_sequential_worker( result_queue: torch.multiprocessing.Queue, ): """Worker function for sequential CFG test (baseline).""" - device = torch.device(f"{device_type}:{local_rank}") - torch_device.set_device(device) + device = torch.device(f"{current_omni_platform.device_type}:{local_rank}") + current_omni_platform.set_device(device) update_environment_variables( { @@ -318,7 +310,7 @@ def test_predict_noise_maybe_with_cfg(cfg_parallel_size: int, dtype: torch.dtype batch_size: Batch size for testing cfg_normalize: Whether to normalize CFG output """ - available_gpus = torch_device.device_count() + available_gpus = current_omni_platform.get_device_count() if available_gpus < cfg_parallel_size: pytest.skip(f"Test requires {cfg_parallel_size} GPUs but only {available_gpus} available") @@ -395,12 +387,12 @@ def test_predict_noise_without_cfg(dtype: torch.dtype): When CFG is disabled, only the positive branch should be computed. This test runs on a single GPU without distributed environment. 
""" - available_gpus = torch_device.device_count() + available_gpus = current_omni_platform.get_device_count() if available_gpus < 1: pytest.skip("Test requires at least 1 GPU") - device = torch.device(f"{device_type}:0") - torch_device.set_device(device) + device = torch.device(f"{current_omni_platform.device_type}:0") + current_omni_platform.set_device(device) # Create pipeline without distributed environment pipeline = TestCFGPipeline(in_channels=4, hidden_dim=128, seed=42) From 117e0de9f760a24a2c7da24a8b313916c8062044 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 30 Jan 2026 10:08:12 +0800 Subject: [PATCH 62/67] mv QwenImageCFGParallelMixin in qwen_image folder Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../diffusion/parallelism_acceleration.md | 35 ++--- .../diffusion/distributed/cfg_parallel.py | 106 --------------- .../diffusion/models/qwen_image/__init__.py | 4 + .../models/qwen_image/cfg_parallel.py | 121 ++++++++++++++++++ .../models/qwen_image/pipeline_qwen_image.py | 4 +- .../qwen_image/pipeline_qwen_image_edit.py | 4 +- .../pipeline_qwen_image_edit_plus.py | 4 +- .../qwen_image/pipeline_qwen_image_layered.py | 4 +- 8 files changed, 157 insertions(+), 125 deletions(-) create mode 100644 vllm_omni/diffusion/models/qwen_image/cfg_parallel.py diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index 748b679c69..7334df1d49 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -449,23 +449,28 @@ vLLM-omni provides `CFGParallelMixin` base class that encapsulates the CFG paral ```python class QwenImageCFGParallelMixin(CFGParallelMixin): + """ + Base Mixin class for Qwen Image pipelines providing shared CFG methods. 
+ """ + def diffuse( self, - prompt_embeds, - prompt_embeds_mask, - negative_prompt_embeds, - negative_prompt_embeds_mask, - latents, - img_shapes, - txt_seq_lens, - negative_txt_seq_lens, - timesteps, - do_true_cfg, - guidance, - true_cfg_scale, - cfg_normalize=True, - ... - ): + prompt_embeds: torch.Tensor, + prompt_embeds_mask: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + negative_prompt_embeds_mask: torch.Tensor, + latents: torch.Tensor, + img_shapes: torch.Tensor, + txt_seq_lens: torch.Tensor, + negative_txt_seq_lens: torch.Tensor, + timesteps: torch.Tensor, + do_true_cfg: bool, + guidance: torch.Tensor, + true_cfg_scale: float, + image_latents: torch.Tensor | None = None, + cfg_normalize: bool = True, + additional_transformer_kwargs: dict[str, Any] | None = None, + ) -> torch.Tensor: self.transformer.do_true_cfg = do_true_cfg for i, t in enumerate(timesteps): diff --git a/vllm_omni/diffusion/distributed/cfg_parallel.py b/vllm_omni/diffusion/distributed/cfg_parallel.py index 27898610f8..ae1c64d17b 100644 --- a/vllm_omni/diffusion/distributed/cfg_parallel.py +++ b/vllm_omni/diffusion/distributed/cfg_parallel.py @@ -243,109 +243,3 @@ def scheduler_step_maybe_with_cfg( latents = self.scheduler_step(noise_pred, t, latents) return latents - - -class QwenImageCFGParallelMixin(CFGParallelMixin): - """ - Base Mixin class for Qwen Image pipelines providing shared CFG methods. 
- """ - - def diffuse( - self, - prompt_embeds: torch.Tensor, - prompt_embeds_mask: torch.Tensor, - negative_prompt_embeds: torch.Tensor, - negative_prompt_embeds_mask: torch.Tensor, - latents: torch.Tensor, - img_shapes: torch.Tensor, - txt_seq_lens: torch.Tensor, - negative_txt_seq_lens: torch.Tensor, - timesteps: torch.Tensor, - do_true_cfg: bool, - guidance: torch.Tensor, - true_cfg_scale: float, - image_latents: torch.Tensor | None = None, - cfg_normalize: bool = True, - additional_transformer_kwargs: dict[str, Any] | None = None, - ) -> torch.Tensor: - """ - Diffusion loop with optional classifier-free guidance. - - Args: - prompt_embeds: Positive prompt embeddings - prompt_embeds_mask: Mask for positive prompt - negative_prompt_embeds: Negative prompt embeddings - negative_prompt_embeds_mask: Mask for negative prompt - latents: Noise latents to denoise - img_shapes: Image shape information - txt_seq_lens: Text sequence lengths for positive prompts - negative_txt_seq_lens: Text sequence lengths for negative prompts - timesteps: Diffusion timesteps - do_true_cfg: Whether to apply CFG - guidance: Guidance scale tensor - true_cfg_scale: CFG scale factor - image_latents: Conditional image latents for editing (default: None) - cfg_normalize: Whether to normalize CFG output (default: True) - additional_transformer_kwargs: Extra kwargs to pass to transformer (default: None) - - Returns: - Denoised latents - """ - self.scheduler.set_begin_index(0) - self.transformer.do_true_cfg = do_true_cfg - additional_transformer_kwargs = additional_transformer_kwargs or {} - - for i, t in enumerate(timesteps): - if self.interrupt: - continue - self._current_timestep = t - - # Broadcast timestep to match batch size - timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) - - # Concatenate image latents with noise latents if available (for editing pipelines) - latent_model_input = latents - if image_latents is not None: - latent_model_input = 
torch.cat([latents, image_latents], dim=1) - - positive_kwargs = { - "hidden_states": latent_model_input, - "timestep": timestep / 1000, - "guidance": guidance, - "encoder_hidden_states_mask": prompt_embeds_mask, - "encoder_hidden_states": prompt_embeds, - "img_shapes": img_shapes, - "txt_seq_lens": txt_seq_lens, - **additional_transformer_kwargs, - } - if do_true_cfg: - negative_kwargs = { - "hidden_states": latent_model_input, - "timestep": timestep / 1000, - "guidance": guidance, - "encoder_hidden_states_mask": negative_prompt_embeds_mask, - "encoder_hidden_states": negative_prompt_embeds, - "img_shapes": img_shapes, - "txt_seq_lens": negative_txt_seq_lens, - **additional_transformer_kwargs, - } - else: - negative_kwargs = None - - # For editing pipelines, we need to slice the output to remove condition latents - output_slice = latents.size(1) if image_latents is not None else None - - # Predict noise with automatic CFG parallel handling - noise_pred = self.predict_noise_maybe_with_cfg( - do_true_cfg, - true_cfg_scale, - positive_kwargs, - negative_kwargs, - cfg_normalize, - output_slice, - ) - - # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync - latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) - - return latents diff --git a/vllm_omni/diffusion/models/qwen_image/__init__.py b/vllm_omni/diffusion/models/qwen_image/__init__.py index 84fa2259d4..4b823ec75d 100644 --- a/vllm_omni/diffusion/models/qwen_image/__init__.py +++ b/vllm_omni/diffusion/models/qwen_image/__init__.py @@ -2,6 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Qwen Image diffusion model components.""" +from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( + QwenImageCFGParallelMixin, +) from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import ( QwenImagePipeline, get_qwen_image_post_process_func, @@ -11,6 +14,7 @@ ) __all__ = [ + "QwenImageCFGParallelMixin", "QwenImagePipeline", 
"QwenImageTransformer2DModel", "get_qwen_image_post_process_func", diff --git a/vllm_omni/diffusion/models/qwen_image/cfg_parallel.py b/vllm_omni/diffusion/models/qwen_image/cfg_parallel.py new file mode 100644 index 0000000000..4ab19d840a --- /dev/null +++ b/vllm_omni/diffusion/models/qwen_image/cfg_parallel.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""CFG Parallel Mixin for Qwen Image series +Shared by +- QwenImagePipeline +- QwenImageEditPipeline +- QwenImageEditPlusPipeline +- QwenImageLayeredPipeline +""" + +from typing import Any + +import torch + +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin + + +class QwenImageCFGParallelMixin(CFGParallelMixin): + """ + Base Mixin class for Qwen Image pipelines providing shared CFG methods. + """ + + def diffuse( + self, + prompt_embeds: torch.Tensor, + prompt_embeds_mask: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + negative_prompt_embeds_mask: torch.Tensor, + latents: torch.Tensor, + img_shapes: torch.Tensor, + txt_seq_lens: torch.Tensor, + negative_txt_seq_lens: torch.Tensor, + timesteps: torch.Tensor, + do_true_cfg: bool, + guidance: torch.Tensor, + true_cfg_scale: float, + image_latents: torch.Tensor | None = None, + cfg_normalize: bool = True, + additional_transformer_kwargs: dict[str, Any] | None = None, + ) -> torch.Tensor: + """ + Diffusion loop with optional classifier-free guidance. 
+ + Args: + prompt_embeds: Positive prompt embeddings + prompt_embeds_mask: Mask for positive prompt + negative_prompt_embeds: Negative prompt embeddings + negative_prompt_embeds_mask: Mask for negative prompt + latents: Noise latents to denoise + img_shapes: Image shape information + txt_seq_lens: Text sequence lengths for positive prompts + negative_txt_seq_lens: Text sequence lengths for negative prompts + timesteps: Diffusion timesteps + do_true_cfg: Whether to apply CFG + guidance: Guidance scale tensor + true_cfg_scale: CFG scale factor + image_latents: Conditional image latents for editing (default: None) + cfg_normalize: Whether to normalize CFG output (default: True) + additional_transformer_kwargs: Extra kwargs to pass to transformer (default: None) + + Returns: + Denoised latents + """ + self.scheduler.set_begin_index(0) + self.transformer.do_true_cfg = do_true_cfg + additional_transformer_kwargs = additional_transformer_kwargs or {} + + for i, t in enumerate(timesteps): + if self.interrupt: + continue + self._current_timestep = t + + # Broadcast timestep to match batch size + timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) + + # Concatenate image latents with noise latents if available (for editing pipelines) + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states_mask": prompt_embeds_mask, + "encoder_hidden_states": prompt_embeds, + "img_shapes": img_shapes, + "txt_seq_lens": txt_seq_lens, + **additional_transformer_kwargs, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states_mask": negative_prompt_embeds_mask, + "encoder_hidden_states": negative_prompt_embeds, + "img_shapes": img_shapes, + 
"txt_seq_lens": negative_txt_seq_lens, + **additional_transformer_kwargs, + } + else: + negative_kwargs = None + + # For editing pipelines, we need to slice the output to remove condition latents + output_slice = latents.size(1) if image_latents is not None else None + + # Predict noise with automatic CFG parallel handling + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg, + true_cfg_scale, + positive_kwargs, + negative_kwargs, + cfg_normalize, + output_slice, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + + return latents diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index 328bc53920..5bffcde75d 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -25,9 +25,11 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.cfg_parallel import QwenImageCFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( + QwenImageCFGParallelMixin, +) from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, ) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index 1941748810..29ba72bc63 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -25,10 +25,12 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from 
vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.cfg_parallel import QwenImageCFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( + QwenImageCFGParallelMixin, +) from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import calculate_shift from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 6661d3fd54..14a187b4bb 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -23,10 +23,12 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.cfg_parallel import QwenImageCFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( + QwenImageCFGParallelMixin, +) from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import calculate_shift from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit import ( calculate_dimensions, diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index c914b5fd83..f3fdedc5d2 100644 --- 
a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -22,13 +22,15 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.cfg_parallel import QwenImageCFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.qwen_image.autoencoder_kl_qwenimage import ( AutoencoderKLQwenImage, ) +from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( + QwenImageCFGParallelMixin, +) from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, ) From 027e717614c598f404189e9e882e8fa7f081ba96 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 30 Jan 2026 10:27:48 +0800 Subject: [PATCH 63/67] check cfg_parallel validity in pipelines Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- .../diffusion/models/flux/pipeline_flux.py | 18 +++++++++ .../models/qwen_image/cfg_parallel.py | 39 +++++++++++++++++++ .../models/qwen_image/pipeline_qwen_image.py | 2 + .../qwen_image/pipeline_qwen_image_edit.py | 2 + .../pipeline_qwen_image_edit_plus.py | 2 + .../qwen_image/pipeline_qwen_image_layered.py | 2 + 6 files changed, 65 insertions(+) diff --git a/vllm_omni/diffusion/models/flux/pipeline_flux.py b/vllm_omni/diffusion/models/flux/pipeline_flux.py index a91601b7ec..b90aaa8ca4 100644 --- a/vllm_omni/diffusion/models/flux/pipeline_flux.py +++ b/vllm_omni/diffusion/models/flux/pipeline_flux.py @@ -22,6 +22,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from 
vllm_omni.diffusion.distributed.parallel_state import get_classifier_free_guidance_world_size from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.flux import FluxTransformer2DModel @@ -554,6 +555,21 @@ def diffuse( latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] return latents + def check_cfg_parallel_validity(self, true_cfg_scale: float, has_neg_prompt: bool): + if get_classifier_free_guidance_world_size() == 1: + return True + + if true_cfg_scale <= 1: + logger.warning("CFG parallel is NOT working correctly when true_cfg_scale <= 1.") + return False + + if not has_neg_prompt: + logger.warning( + "CFG parallel is NOT working correctly when there is no negative prompt or negative prompt embeddings." + ) + return False + return True + def forward( self, req: OmniDiffusionRequest, @@ -637,6 +653,8 @@ def forward( ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt) + ( prompt_embeds, pooled_prompt_embeds, diff --git a/vllm_omni/diffusion/models/qwen_image/cfg_parallel.py b/vllm_omni/diffusion/models/qwen_image/cfg_parallel.py index 4ab19d840a..9a882f7bf0 100644 --- a/vllm_omni/diffusion/models/qwen_image/cfg_parallel.py +++ b/vllm_omni/diffusion/models/qwen_image/cfg_parallel.py @@ -8,11 +8,15 @@ - QwenImageLayeredPipeline """ +import logging from typing import Any import torch from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.parallel_state import get_classifier_free_guidance_world_size + +logger = logging.getLogger(__name__) class QwenImageCFGParallelMixin(CFGParallelMixin): @@ -119,3 +123,38 @@ def diffuse( latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) return latents + + def check_cfg_parallel_validity(self, true_cfg_scale: float, 
has_neg_prompt: bool): + """ + Validate whether CFG parallel is properly configured for the current generation request. + + When CFG parallel is enabled (cfg_parallel_world_size > 1), this method verifies that the necessary + conditions are met for correct parallel execution. If validation fails, a warning is + logged to help identify configuration issues. + + Args: + true_cfg_scale: The classifier-free guidance scale value. Must be > 1 for CFG to + have an effect. + has_neg_prompt: Whether negative prompts or negative prompt embeddings are provided. + Required for CFG to perform unconditional prediction. + + Returns: + True if CFG parallel is disabled or all validation checks pass, False otherwise. + + Note: + When CFG parallel is disabled (world_size == 1), this method always returns True + as no parallel-specific validation is needed. + """ + if get_classifier_free_guidance_world_size() == 1: + return True + + if true_cfg_scale <= 1: + logger.warning("CFG parallel is NOT working correctly when true_cfg_scale <= 1.") + return False + + if not has_neg_prompt: + logger.warning( + "CFG parallel is NOT working correctly when there is no negative prompt or negative prompt embeddings." 
+ ) + return False + return True diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index 5bffcde75d..d85d98b5bf 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -613,6 +613,8 @@ def forward( ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt) + prompt_embeds, prompt_embeds_mask = self.encode_prompt( prompt=prompt, prompt_embeds=prompt_embeds, diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index 29ba72bc63..78fd92c9d5 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -706,6 +706,8 @@ def forward( ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt) + prompt_embeds, prompt_embeds_mask = self.encode_prompt( prompt=prompt, image=prompt_image, # Use resized image for prompt encoding diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 14a187b4bb..00e7758029 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -661,6 +661,8 @@ def forward( ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt) + prompt_embeds, prompt_embeds_mask = self.encode_prompt( prompt=prompt, image=condition_images, # Use condition images for prompt encoding diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py 
b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index f3fdedc5d2..d200764ebf 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -706,6 +706,8 @@ def forward( ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt) + prompt_embeds, prompt_embeds_mask = self.encode_prompt( prompt=prompt, prompt_embeds=prompt_embeds, From 35794701710fd39e2a3c5c3d6c3ef2e0cb520ff9 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 30 Jan 2026 10:38:30 +0800 Subject: [PATCH 64/67] fix unit test spawn process error Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- tests/diffusion/distributed/test_cfg_parallel.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/diffusion/distributed/test_cfg_parallel.py b/tests/diffusion/distributed/test_cfg_parallel.py index 7b997f112d..24e4559de3 100644 --- a/tests/diffusion/distributed/test_cfg_parallel.py +++ b/tests/diffusion/distributed/test_cfg_parallel.py @@ -326,9 +326,11 @@ def test_predict_noise_maybe_with_cfg(cfg_parallel_size: int, dtype: torch.dtype "input_seed": 123, # Fixed seed for input generation } - # Create queues for receiving results - baseline_queue = torch.multiprocessing.Queue() - cfg_parallel_queue = torch.multiprocessing.Queue() + mp_context = torch.multiprocessing.get_context("spawn") + + manager = mp_context.Manager() + baseline_queue = manager.Queue() + cfg_parallel_queue = manager.Queue() # Run baseline (sequential CFG) on single GPU torch.multiprocessing.spawn( From 6a3070bd449f25a516d4d1c030b55d7f95bb4bf8 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 30 Jan 2026 15:11:44 +0800 Subject: [PATCH 65/67] rm mps related code Signed-off-by: Didan Deng 
<33117903+wtomin@users.noreply.github.com> --- .../image_to_video/image_to_video.py | 10 ++++++++++ .../offline_inference/text_to_video/text_to_video.py | 4 +++- .../models/flux2_klein/pipeline_flux2_klein.py | 12 ------------ .../models/longcat_image/pipeline_longcat_image.py | 12 ------------ .../longcat_image/pipeline_longcat_image_edit.py | 12 ------------ .../models/ovis_image/pipeline_ovis_image.py | 12 ------------ 6 files changed, 13 insertions(+), 49 deletions(-) diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py index 646e002bfd..8e8d399155 100644 --- a/examples/offline_inference/image_to_video/image_to_video.py +++ b/examples/offline_inference/image_to_video/image_to_video.py @@ -153,6 +153,16 @@ def main(): print("[Profiler] Starting profiling...") omni.start_profile() + # Print generation configuration + print(f"\n{'=' * 60}") + print("Generation Configuration:") + print(f" Model: {args.model}") + print(f" Inference steps: {args.num_inference_steps}") + print(f" Frames: {args.num_frames}") + print(f" Parallel configuration: cfg_parallel_size={args.cfg_parallel_size}") + print(f" Video size: {args.width}x{args.height}") + print(f"{'=' * 60}\n") + # omni.generate() returns Generator[OmniRequestOutput, None, None] frames = omni.generate( { diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index 8d14e93c2c..e9dd2d0856 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -172,7 +172,9 @@ def main(): print(f" Model: {args.model}") print(f" Inference steps: {args.num_inference_steps}") print(f" Frames: {args.num_frames}") - print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}") + print( + f" Parallel configuration: ulysses_degree={args.ulysses_degree}, 
ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}" + ) print(f" Video size: {args.width}x{args.height}") print(f"{'=' * 60}\n") diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index 3a5be60070..e1ef706c3f 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -991,18 +991,6 @@ def forward( return DiffusionOutput(output=image) - def scheduler_step(self, noise_pred, t, latents): - """ - Step the scheduler. - """ - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) - return latents - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py index 3b75c36abf..09f409f313 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py @@ -672,18 +672,6 @@ def forward( return DiffusionOutput(output=image) - def scheduler_step(self, noise_pred, t, latents): - """ - Step the scheduler. 
- """ - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) - return latents - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights using AutoWeightsLoader for vLLM integration.""" loader = AutoWeightsLoader(self) diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py index be26de538f..a34a2cca39 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py @@ -701,18 +701,6 @@ def forward( image = self.vae.decode(latents, return_dict=False)[0] return DiffusionOutput(output=image) - def scheduler_step(self, noise_pred, t, latents): - """ - Step the scheduler. - """ - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) - return latents - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights using AutoWeightsLoader for vLLM integration.""" loader = AutoWeightsLoader(self) diff --git a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py index db87386804..963f1c483b 100644 --- a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py +++ b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py @@ -504,18 +504,6 @@ def diffuse( return latents - def scheduler_step(self, noise_pred, t, latents): - """ - Step the scheduler. 
- """ - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) - return latents - @property def guidance_scale(self): return self._guidance_scale From a939170e7dd7398e716b1e2e5e4d5cf82e984756 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 30 Jan 2026 16:01:35 +0800 Subject: [PATCH 66/67] mv empty_cache to wan pipelines after all diffusion steps Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/distributed/cfg_parallel.py | 10 ---------- vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py | 2 ++ .../diffusion/models/wan2_2/pipeline_wan2_2_i2v.py | 2 ++ .../diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py | 2 ++ 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/vllm_omni/diffusion/distributed/cfg_parallel.py b/vllm_omni/diffusion/distributed/cfg_parallel.py index ae1c64d17b..9f86bce228 100644 --- a/vllm_omni/diffusion/distributed/cfg_parallel.py +++ b/vllm_omni/diffusion/distributed/cfg_parallel.py @@ -68,22 +68,12 @@ def predict_noise_maybe_with_cfg( gathered = cfg_group.all_gather(local_pred, separate_tensors=True) - del local_pred - if cfg_rank == 0: noise_pred = gathered[0] neg_noise_pred = gathered[1] noise_pred = self.combine_cfg_noise(noise_pred, neg_noise_pred, true_cfg_scale, cfg_normalize) - - del gathered, neg_noise_pred - if torch.cuda.is_available(): - torch.cuda.empty_cache() - return noise_pred else: - del gathered - if torch.cuda.is_available(): - torch.cuda.empty_cache() return None else: # Sequential CFG: compute both positive and negative diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 8979a47f88..f8de284fa0 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ 
b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -620,6 +620,8 @@ def forward( # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + if torch.cuda.is_available(): + torch.cuda.empty_cache() self._current_timestep = None # For I2V mode: blend final latents with condition diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 1e166d3482..facf472c0c 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -521,6 +521,8 @@ def forward( # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + if torch.cuda.is_available(): + torch.cuda.empty_cache() self._current_timestep = None # For expand_timesteps mode, blend final latents with condition diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index 1a2b94f8a3..4ebd3ecbc1 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -434,6 +434,8 @@ def forward( # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + if torch.cuda.is_available(): + torch.cuda.empty_cache() self._current_timestep = None # For I2V mode, blend final latents with condition From 5a9af70b7a8928428c971134c0bffe4451d98d66 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 30 Jan 2026 17:09:33 +0800 Subject: [PATCH 67/67] omni_platform and comment Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py | 7 +++++-- 
vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py | 7 +++++-- vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py | 7 +++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index f8de284fa0..b902bc692e 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -25,6 +25,7 @@ from vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import WanTransformer3DModel from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.platforms import current_omni_platform logger = logging.getLogger(__name__) @@ -620,8 +621,10 @@ def forward( # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + # Wan2.2 is prone to out of memory errors when predicting large videos + # so we empty the cache here to avoid OOM before vae decoding. 
+ if current_omni_platform.is_available(): + current_omni_platform.empty_cache() self._current_timestep = None # For I2V mode: blend final latents with condition diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index facf472c0c..1aed9b75de 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -30,6 +30,7 @@ ) from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.platforms import current_omni_platform logger = logging.getLogger(__name__) @@ -521,8 +522,10 @@ def forward( # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + # Wan2.2 is prone to out of memory errors when predicting large videos + # so we empty the cache here to avoid OOM before vae decoding. 
+ if current_omni_platform.is_available(): + current_omni_platform.empty_cache() self._current_timestep = None # For expand_timesteps mode, blend final latents with condition diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index 4ebd3ecbc1..d32b7d697c 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -43,6 +43,7 @@ ) from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.platforms import current_omni_platform logger = logging.getLogger(__name__) @@ -434,8 +435,10 @@ def forward( # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + # Wan2.2 is prone to out of memory errors when predicting large videos + # so we empty the cache here to avoid OOM before vae decoding. + if current_omni_platform.is_available(): + current_omni_platform.empty_cache() self._current_timestep = None # For I2V mode, blend final latents with condition