diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index 1a9f540ad3..7334df1d49 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -373,9 +373,9 @@ def forward(self, hidden_states, ...): ### CFG-Parallel -##### Offline Inference +#### Offline Inference -CFG-Parallel is enabled through `DiffusionParallelConfig(cfg_parallel_size=...)`. The recommended configuration is `cfg_parallel_size=2` (one rank for the positive branch and one rank for the negative branch). +CFG-Parallel is enabled through `DiffusionParallelConfig(cfg_parallel_size=2)`, which runs one rank for the positive branch and one rank for the negative branch. An example of offline inference using CFG-Parallel (image-to-image) is shown below: @@ -383,19 +383,21 @@ An example of offline inference using CFG-Parallel (image-to-image) is shown bel from vllm_omni import Omni from vllm_omni.diffusion.data import DiffusionParallelConfig +image_path = "path_to_image.png" omni = Omni( model="Qwen/Qwen-Image-Edit", parallel_config=DiffusionParallelConfig(cfg_parallel_size=2), ) +input_image = Image.open(image_path).convert("RGB") outputs = omni.generate( { "prompt": "turn this cat to a dog", "negative_prompt": "low quality, blurry", + "multi_modal_data": {"image": input_image}, }, OmniDiffusionSamplingParams( true_cfg_scale=4.0, - pil_image=input_image, num_inference_steps=50, ), ) @@ -403,7 +405,28 @@ outputs = omni.generate( Notes: -- CFG-Parallel is only effective when **true CFG** is enabled (i.e., `true_cfg_scale > 1` and a `negative_prompt` is provided). +- CFG-Parallel is only effective when a `negative_prompt` is provided AND a guidance scale (or `cfg_scale`) is greater than 1. + +See `examples/offline_inference/image_to_image/image_edit.py` for a complete working example. 
+```bash +cd examples/offline_inference/image_to_image/ +python image_edit.py \ + --model "Qwen/Qwen-Image-Edit" \ + --image "qwen_image_output.png" \ + --prompt "turn this cat to a dog" \ + --negative_prompt "low quality, blurry" \ + --cfg_scale 4.0 \ + --output "edited_image.png" \ + --cfg_parallel_size 2 +``` + +#### Online Serving + +You can enable CFG-Parallel in online serving for diffusion models via `--cfg-parallel-size`: + +```bash +vllm serve Qwen/Qwen-Image-Edit --omni --port 8091 --cfg-parallel-size 2 +``` #### How to parallelize a pipeline @@ -416,58 +439,130 @@ In `QwenImagePipeline`, each diffusion step runs two denoiser forward passes seq CFG-Parallel assigns these two branches to different ranks in the **CFG group** and synchronizes the results. -Below is an example of CFG-Parallel implementation: +vLLM-omni provides a `CFGParallelMixin` base class that encapsulates the CFG-Parallel logic. By inheriting from this mixin and calling its methods, pipelines can implement CFG-Parallel without writing repetitive code. + +**Key Methods in CFGParallelMixin:** +- `predict_noise_maybe_with_cfg()`: Automatically handles CFG-Parallel noise prediction +- `scheduler_step_maybe_with_cfg()`: Scheduler step with automatic CFG rank synchronization + +**Example Implementation:** ```python -def diffuse( +class QwenImageCFGParallelMixin(CFGParallelMixin): + """ + Base Mixin class for Qwen Image pipelines providing shared CFG methods. + """ + + def diffuse( self, - ... - ): - # Enable CFG-parallel: rank0 computes positive, rank1 computes negative. 
- cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 - - self.transformer.do_true_cfg = do_true_cfg - - if cfg_parallel_ready: - cfg_group = get_cfg_group() - cfg_rank = get_classifier_free_guidance_rank() - - if cfg_rank == 0: - local_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - else: - local_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - - gathered = cfg_group.all_gather(local_pred, separate_tensors=True) - if cfg_rank == 0: - noise_pred = gathered[0] - neg_noise_pred = gathered[1] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - cfg_group.broadcast(latents, src=0) - else: - # fallback: run positive then negative sequentially on one rank - ... 
+ prompt_embeds: torch.Tensor, + prompt_embeds_mask: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + negative_prompt_embeds_mask: torch.Tensor, + latents: torch.Tensor, + img_shapes: torch.Tensor, + txt_seq_lens: torch.Tensor, + negative_txt_seq_lens: torch.Tensor, + timesteps: torch.Tensor, + do_true_cfg: bool, + guidance: torch.Tensor, + true_cfg_scale: float, + image_latents: torch.Tensor | None = None, + cfg_normalize: bool = True, + additional_transformer_kwargs: dict[str, Any] | None = None, + ) -> torch.Tensor: + self.transformer.do_true_cfg = do_true_cfg + + for i, t in enumerate(timesteps): + timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) + + # Prepare kwargs for positive (conditional) prediction + positive_kwargs = { + "hidden_states": latents, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states_mask": prompt_embeds_mask, + "encoder_hidden_states": prompt_embeds, + "img_shapes": img_shapes, + "txt_seq_lens": txt_seq_lens, + } + + # Prepare kwargs for negative (unconditional) prediction + if do_true_cfg: + negative_kwargs = { + "hidden_states": latents, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states_mask": negative_prompt_embeds_mask, + "encoder_hidden_states": negative_prompt_embeds, + "img_shapes": img_shapes, + "txt_seq_lens": negative_txt_seq_lens, + } + else: + negative_kwargs = None + + # Predict noise with automatic CFG parallel handling + # - In CFG parallel mode: rank0 computes positive, rank1 computes negative + # - Automatically gathers results and combines them on rank0 + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=true_cfg_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=cfg_normalize, + ) + + # Step scheduler with automatic CFG synchronization + # - Only rank0 computes the scheduler step + # - Automatically broadcasts updated latents to all ranks 
+ latents = self.scheduler_step_maybe_with_cfg( + noise_pred, t, latents, do_true_cfg + ) + + return latents +``` + +**How it works:** +1. Prepare separate `positive_kwargs` and `negative_kwargs` for conditional and unconditional predictions +2. Call `predict_noise_maybe_with_cfg()` which: + - Detects if CFG parallel is enabled (`get_classifier_free_guidance_world_size() > 1`) + - Distributes computation: rank0 processes positive, rank1 processes negative + - Gathers predictions and combines them using `combine_cfg_noise()` on rank0 + - Returns combined noise prediction (only valid on rank0) +3. Call `scheduler_step_maybe_with_cfg()` which: + - Only rank0 computes the scheduler step + - Broadcasts the updated latents to all ranks for synchronization + +**How to customize** + +Some pipelines may need to customize the following functions in `CFGParallelMixin`: +1. You may need to edit `predict_noise` function for custom behaviors. +```python +def predict_noise(self, *args, **kwargs): + """ + Forward pass through transformer to predict noise. + + Subclasses should override this if they need custom behavior, + but the default implementation calls self.transformer. + """ + return self.transformer(*args, **kwargs)[0] + +``` +2. The default normalization function after combining the noise predictions from both branches is as follows. You may need to customize it. +```python +def cfg_normalize_function(self, noise_pred, comb_pred): + """ + Normalize the combined noise prediction. 
+ + Args: + noise_pred: positive noise prediction + comb_pred: combined noise prediction after CFG + + Returns: + Normalized noise prediction tensor + """ + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + return noise_pred ``` diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 8d31747d21..d081243782 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -39,23 +39,23 @@ The following table shows which models are currently supported by each accelerat | Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | CFG-Parallel | |-------|------------------|:----------:|:-----------:|:-----------:|:----------------:|:----------------:| -| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | ✅ | ❌ | ❌ | ❌ | -| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ❌ | ✅ | ❌ | ❌ | ❌ | -| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ✅ | ❌ | ❌ | ❌ | +| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | ✅ | ❌ | ❌ | ✅ | +| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ❌ | ✅ | ❌ | ❌ | ✅ | +| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ✅ | ❌ | ❌ | ✅ | | **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | ✅ | | **Qwen-Image-2512** | `Qwen/Qwen-Image-2512` | ✅ | ✅ | ✅ | ✅ | ✅ | | **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | ✅ | | **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | ✅ | | **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ❌ | ✅ | ✅ | ✅ | ✅ | | **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ❌ | ❌ | -| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ❌ | +| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ✅ | | **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ | ### 
VideoGen -| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | -|-------|------------------|:--------:|:---------:|:----------:|:--------------:| -| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ❌ | ✅ | ❌ | ❌ | +| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | CFG-Parallel | +|-------|------------------|:--------:|:---------:|:----------:|:--------------:|:----------------:| +| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ❌ | ✅ | ✅ | ✅ | ✅ | ## Performance Benchmarks diff --git a/docs/user_guide/examples/offline_inference/image_to_image.md b/docs/user_guide/examples/offline_inference/image_to_image.md index 78c7f11fcb..c970106d2b 100644 --- a/docs/user_guide/examples/offline_inference/image_to_image.md +++ b/docs/user_guide/examples/offline_inference/image_to_image.md @@ -47,6 +47,7 @@ Key arguments: - `--image`: path(s) to the source image(s) (PNG/JPG, converted to RGB). Can specify multiple images. - `--prompt` / `--negative_prompt`: text description (string). - `--cfg_scale`: true classifier-free guidance scale (default: 4.0). Classifier-free guidance is enabled by setting cfg_scale > 1 and providing a negative_prompt. Higher guidance scale encourages images closely linked to the text prompt, usually at the expense of lower image quality. +- `--cfg_parallel_size`: number of devices used for CFG-Parallel. CFG-Parallel takes effect only when classifier-free guidance is enabled and `cfg_parallel_size` is set to 2. - `--guidance_scale`: guidance scale for guidance-distilled models (default: 1.0, disabled). Unlike classifier-free guidance (--cfg_scale), guidance-distilled models take the guidance scale directly as an input parameter. Enabled when guidance_scale > 1. Ignored when not using guidance-distilled models. - `--num_inference_steps`: diffusion sampling steps (more steps = higher quality, slower). - `--output`: path to save the generated PNG. 
diff --git a/docs/user_guide/examples/offline_inference/image_to_video.md b/docs/user_guide/examples/offline_inference/image_to_video.md index 32c655ffd7..d65839dd75 100644 --- a/docs/user_guide/examples/offline_inference/image_to_video.md +++ b/docs/user_guide/examples/offline_inference/image_to_video.md @@ -52,6 +52,7 @@ Key arguments: - `--num_frames`: Number of frames (default 81). - `--guidance_scale` and `--guidance_scale_high`: CFG scale (applied to low/high-noise stages for MoE). - `--negative_prompt`: Optional list of artifacts to suppress. +- `--cfg_parallel_size`: Number of devices used for CFG-Parallel. CFG-Parallel takes effect only when classifier-free guidance is enabled and `cfg_parallel_size` is set to 2. - `--boundary_ratio`: Boundary split ratio for two-stage MoE models. - `--flow_shift`: Scheduler flow shift (5.0 for 720p, 12.0 for 480p). - `--num_inference_steps`: Number of denoising steps (default 50). diff --git a/docs/user_guide/examples/offline_inference/text_to_image.md b/docs/user_guide/examples/offline_inference/text_to_image.md index 2dd0f7b7cc..486b9f63b0 100644 --- a/docs/user_guide/examples/offline_inference/text_to_image.md +++ b/docs/user_guide/examples/offline_inference/text_to_image.md @@ -95,6 +95,7 @@ Key arguments: - `--prompt`: text description (string). - `--seed`: integer seed for deterministic sampling. - `--cfg_scale`: true CFG scale (model-specific guidance strength). +- `--cfg_parallel_size`: number of devices used for CFG-Parallel. CFG-Parallel takes effect only when classifier-free guidance is enabled and `cfg_parallel_size` is set to 2. - `--num_images_per_prompt`: number of images to generate per prompt (saves as `output`, `output_1`, ...). - `--num_inference_steps`: diffusion sampling steps (more steps = higher quality, slower). - `--height/--width`: output resolution (defaults 1024x1024). 
diff --git a/docs/user_guide/examples/offline_inference/text_to_video.md b/docs/user_guide/examples/offline_inference/text_to_video.md index 0cc22edea5..db0860b38e 100644 --- a/docs/user_guide/examples/offline_inference/text_to_video.md +++ b/docs/user_guide/examples/offline_inference/text_to_video.md @@ -28,6 +28,7 @@ Key arguments: - `--num_frames`: Number of frames (Wan default is 81). - `--guidance_scale` and `--guidance_scale_high`: CFG scale (applied to low/high).. - `--negative_prompt`: optional list of artifacts to suppress (the PR demo used a long Chinese string). +- `--cfg_parallel_size`: number of devices used for CFG-Parallel. CFG-Parallel takes effect only when classifier-free guidance is enabled and `cfg_parallel_size` is set to 2. - `--boundary_ratio`: Boundary split ratio for low/high DiT. - `--fps`: frames per second for the saved MP4 (requires `diffusers` export_to_video). - `--output`: path to save the generated video. diff --git a/examples/offline_inference/image_to_image/image_to_image.md b/examples/offline_inference/image_to_image/image_to_image.md index b9e25bf986..d0986d6ee7 100644 --- a/examples/offline_inference/image_to_image/image_to_image.md +++ b/examples/offline_inference/image_to_image/image_to_image.md @@ -49,6 +49,7 @@ Key arguments: - `--output`: path to save the generated PNG. - `--vae_use_slicing`: enable VAE slicing for memory optimization. - `--vae_use_tiling`: enable VAE tiling for memory optimization. +- `--cfg_parallel_size`: set to 2 to enable CFG-Parallel. See the [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel) for more examples. - `--enable-cpu-offload`: enable CPU offloading for diffusion models. > ℹ️ If you encounter OOM errors, try using `--vae_use_slicing` and `--vae_use_tiling` to reduce memory usage. 
diff --git a/examples/offline_inference/image_to_video/README.md b/examples/offline_inference/image_to_video/README.md index 52bb389e73..a1355dab69 100644 --- a/examples/offline_inference/image_to_video/README.md +++ b/examples/offline_inference/image_to_video/README.md @@ -56,6 +56,7 @@ Key arguments: - `--output`: Path to save the generated video. - `--vae_use_slicing`: Enable VAE slicing for memory optimization. - `--vae_use_tiling`: Enable VAE tiling for memory optimization. +- `--cfg_parallel_size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel). - `--enable-cpu-offload`: enable CPU offloading for diffusion models. > ℹ️ If you encounter OOM errors, try using `--vae_use_slicing` and `--vae_use_tiling` to reduce memory usage. diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py index 1785287849..8e8d399155 100644 --- a/examples/offline_inference/image_to_video/image_to_video.py +++ b/examples/offline_inference/image_to_video/image_to_video.py @@ -26,6 +26,7 @@ import PIL.Image import torch +from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput @@ -85,6 +86,18 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of ready layers (blocks) to keep on GPU during generation.", ) + parser.add_argument( + "--cfg_parallel_size", + type=int, + default=1, + choices=[1, 2], + help="Number of GPUs used for classifier free guidance parallel size.", + ) + parser.add_argument( + "--enforce_eager", + action="store_true", + help="Disable torch.compile and force eager execution.", + ) return parser.parse_args() @@ -120,7 +133,9 @@ def main(): # Check if profiling is requested via environment variable profiler_enabled = 
bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) - + parallel_config = DiffusionParallelConfig( + cfg_parallel_size=args.cfg_parallel_size, + ) omni = Omni( model=args.model, enable_layerwise_offload=args.enable_layerwise_offload, @@ -130,12 +145,24 @@ def main(): boundary_ratio=args.boundary_ratio, flow_shift=args.flow_shift, enable_cpu_offload=args.enable_cpu_offload, + parallel_config=parallel_config, + enforce_eager=args.enforce_eager, ) if profiler_enabled: print("[Profiler] Starting profiling...") omni.start_profile() + # Print generation configuration + print(f"\n{'=' * 60}") + print("Generation Configuration:") + print(f" Model: {args.model}") + print(f" Inference steps: {args.num_inference_steps}") + print(f" Frames: {args.num_frames}") + print(f" Parallel configuration: cfg_parallel_size={args.cfg_parallel_size}") + print(f" Video size: {args.width}x{args.height}") + print(f"{'=' * 60}\n") + # omni.generate() returns Generator[OmniRequestOutput, None, None] frames = omni.generate( { diff --git a/examples/offline_inference/text_to_image/README.md b/examples/offline_inference/text_to_image/README.md index ab28b115c1..9c57a621cf 100644 --- a/examples/offline_inference/text_to_image/README.md +++ b/examples/offline_inference/text_to_image/README.md @@ -98,6 +98,7 @@ Key arguments: - `--output`: path to save the generated PNG. - `--vae_use_slicing`: enable VAE slicing for memory optimization. - `--vae_use_tiling`: enable VAE tiling for memory optimization. +- `--cfg_parallel_size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel). - `--enable-cpu-offload`: enable CPU offloading for diffusion models. > ℹ️ If you encounter OOM errors, try using `--vae_use_slicing` and `--vae_use_tiling` to reduce memory usage. 
diff --git a/examples/offline_inference/text_to_video/text_to_video.md b/examples/offline_inference/text_to_video/text_to_video.md index b5d0b2adc2..04f1a2653b 100644 --- a/examples/offline_inference/text_to_video/text_to_video.md +++ b/examples/offline_inference/text_to_video/text_to_video.md @@ -31,6 +31,7 @@ Key arguments: - `--output`: path to save the generated video. - `--vae_use_slicing`: enable VAE slicing for memory optimization. - `--vae_use_tiling`: enable VAE tiling for memory optimization. +- `--cfg_parallel_size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel). - `--enable-cpu-offload`: enable CPU offloading for diffusion models. > ℹ️ If you encounter OOM errors, try using `--vae_use_slicing` and `--vae_use_tiling` to reduce memory usage. diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index 63474987fa..e9dd2d0856 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -102,6 +102,14 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of GPUs used for ring sequence parallelism.", ) + parser.add_argument( + "--cfg_parallel_size", + type=int, + default=1, + choices=[1, 2], + help="Number of GPUs used for classifier free guidance parallel size.", + ) + return parser.parse_args() @@ -132,6 +140,7 @@ def main(): parallel_config = DiffusionParallelConfig( ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree, + cfg_parallel_size=args.cfg_parallel_size, ) # Check if profiling is requested via environment variable @@ -163,7 +172,9 @@ def main(): print(f" Model: {args.model}") print(f" Inference steps: {args.num_inference_steps}") print(f" Frames: {args.num_frames}") - print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}") + 
print( + f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}" + ) print(f" Video size: {args.width}x{args.height}") print(f"{'=' * 60}\n") diff --git a/tests/diffusion/distributed/test_cfg_parallel.py b/tests/diffusion/distributed/test_cfg_parallel.py new file mode 100644 index 0000000000..24e4559de3 --- /dev/null +++ b/tests/diffusion/distributed/test_cfg_parallel.py @@ -0,0 +1,423 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for CFG (Classifier-Free Guidance) parallel functionality. + +This test verifies that predict_noise_maybe_with_cfg produces numerically +equivalent results with and without CFG parallel using fixed random inputs. +""" + +import os + +import pytest +import torch + +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.parallel_state import ( + destroy_distributed_env, + get_classifier_free_guidance_rank, + get_classifier_free_guidance_world_size, + init_distributed_environment, + initialize_model_parallel, +) +from vllm_omni.platforms import current_omni_platform + + +def update_environment_variables(envs_dict: dict[str, str]): + """Update multiple environment variables.""" + for k, v in envs_dict.items(): + os.environ[k] = v + + +class SimpleTransformer(torch.nn.Module): + """Simple transformer model for testing with random initialization. 
+ + Contains: + - Input projection (conv to hidden_dim) + - QKV projection layers + - Self-attention layer + - Output projection + """ + + def __init__(self, in_channels: int = 4, hidden_dim: int = 128, num_heads: int = 8): + super().__init__() + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads + + assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads" + + # Input projection: (B, C, H, W) -> (B, hidden_dim, H, W) + self.input_proj = torch.nn.Conv2d(in_channels, hidden_dim, 1) + + # QKV projection layers + self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim) + self.k_proj = torch.nn.Linear(hidden_dim, hidden_dim) + self.v_proj = torch.nn.Linear(hidden_dim, hidden_dim) + + # Output projection after attention + self.out_proj = torch.nn.Linear(hidden_dim, hidden_dim) + + # Final output projection: (B, hidden_dim, H, W) -> (B, C, H, W) + self.final_proj = torch.nn.Conv2d(hidden_dim, in_channels, 1) + + # Layer norm + self.norm1 = torch.nn.LayerNorm(hidden_dim) + self.norm2 = torch.nn.LayerNorm(hidden_dim) + + def forward(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor]: + """Forward pass with self-attention. 
+ + Args: + x: Input tensor of shape (B, C, H, W) + + Returns: + Output tensor of shape (B, C, H, W) + """ + B, C, H, W = x.shape + + # Input projection + x = self.input_proj(x) # (B, hidden_dim, H, W) + + # Reshape to sequence: (B, hidden_dim, H, W) -> (B, H*W, hidden_dim) + x = x.flatten(2).transpose(1, 2) # (B, H*W, hidden_dim) + + # Self-attention with residual connection + residual = x + x = self.norm1(x) + + # QKV projection + q = self.q_proj(x) # (B, H*W, hidden_dim) + k = self.k_proj(x) # (B, H*W, hidden_dim) + v = self.v_proj(x) # (B, H*W, hidden_dim) + + # Reshape for multi-head attention: (B, H*W, hidden_dim) -> (B, num_heads, H*W, head_dim) + seq_len = H * W + q = q.view(B, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(B, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(B, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + + # Scaled dot-product attention + scale = self.head_dim**-0.5 + attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale # (B, num_heads, H*W, H*W) + attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(1, 2).contiguous().view(B, seq_len, self.hidden_dim) + + attn_output = self.out_proj(attn_output) + + x = residual + attn_output + residual = x + x = self.norm2(x) + x = residual + x + x = x.transpose(1, 2).view(B, self.hidden_dim, H, W) + + out = self.final_proj(x) + + return (out,) + + +class TestCFGPipeline(CFGParallelMixin): + """Test pipeline using CFGParallelMixin.""" + + def __init__(self, in_channels: int = 4, hidden_dim: int = 128, seed: int = 42): + # Set seed BEFORE creating transformer to ensure consistent layer initialization + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + self.transformer = SimpleTransformer(in_channels, hidden_dim) + + # Re-initialize all parameters with fixed seed for full reproducibility + 
torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + for param in self.transformer.parameters(): + torch.nn.init.normal_(param, mean=0.0, std=0.02) + + +def _test_cfg_parallel_worker( + local_rank: int, + world_size: int, + cfg_parallel_size: int, + dtype: torch.dtype, + test_config: dict, + result_queue: torch.multiprocessing.Queue, +): + """Worker function for CFG parallel test.""" + device = torch.device(f"{current_omni_platform.device_type}:{local_rank}") + current_omni_platform.set_device(device) + + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "29502", + } + ) + + init_distributed_environment() + initialize_model_parallel(cfg_parallel_size=cfg_parallel_size) + + cfg_rank = get_classifier_free_guidance_rank() + cfg_world_size = get_classifier_free_guidance_world_size() + + assert cfg_world_size == cfg_parallel_size + + # Create pipeline with same seed to ensure identical model weights across all ranks + # Note: model_seed is set inside TestCFGPipeline.__init__ + pipeline = TestCFGPipeline( + in_channels=test_config["channels"], + hidden_dim=test_config["hidden_dim"], + seed=test_config["model_seed"], + ) + pipeline.transformer = pipeline.transformer.to(device=device, dtype=dtype) + pipeline.transformer.eval() # Set to eval mode for deterministic behavior + + # Create fixed inputs with explicit seed setting for reproducibility + # Set both CPU and CUDA seeds to ensure identical inputs across all ranks + torch.manual_seed(test_config["input_seed"]) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(test_config["input_seed"]) + + batch_size = test_config["batch_size"] + channels = test_config["channels"] + height = test_config["height"] + width = test_config["width"] + + # Positive input + positive_input = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device) + + # 
Negative input with different seed + torch.manual_seed(test_config["input_seed"] + 1) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(test_config["input_seed"] + 1) + negative_input = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device) + + # Prepare kwargs for predict_noise_maybe_with_cfg + positive_kwargs = {"x": positive_input} + negative_kwargs = {"x": negative_input} + + with torch.no_grad(): + # Call predict_noise_maybe_with_cfg + noise_pred = pipeline.predict_noise_maybe_with_cfg( + do_true_cfg=True, + true_cfg_scale=test_config["cfg_scale"], + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=test_config["cfg_normalize"], + ) + + # Only rank 0 has valid output in CFG parallel mode + if cfg_rank == 0: + assert noise_pred is not None + result_queue.put(noise_pred.cpu()) + else: + assert noise_pred is None + + destroy_distributed_env() + + +def _test_cfg_sequential_worker( + local_rank: int, + world_size: int, + dtype: torch.dtype, + test_config: dict, + result_queue: torch.multiprocessing.Queue, +): + """Worker function for sequential CFG test (baseline).""" + device = torch.device(f"{current_omni_platform.device_type}:{local_rank}") + current_omni_platform.set_device(device) + + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "29503", + } + ) + + init_distributed_environment() + initialize_model_parallel(cfg_parallel_size=1) # No CFG parallel + + cfg_world_size = get_classifier_free_guidance_world_size() + assert cfg_world_size == 1 + + # Create pipeline with same seed to ensure identical model weights as CFG parallel + # Note: model_seed is set inside TestCFGPipeline.__init__ + pipeline = TestCFGPipeline( + in_channels=test_config["channels"], + hidden_dim=test_config["hidden_dim"], + seed=test_config["model_seed"], + ) + pipeline.transformer = 
pipeline.transformer.to(device=device, dtype=dtype) + pipeline.transformer.eval() + + # Create fixed inputs (same seed as CFG parallel to ensure identical inputs) + # Set both CPU and CUDA seeds for full reproducibility + torch.manual_seed(test_config["input_seed"]) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(test_config["input_seed"]) + + batch_size = test_config["batch_size"] + channels = test_config["channels"] + height = test_config["height"] + width = test_config["width"] + + # Positive input + positive_input = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device) + + # Negative input with different seed + torch.manual_seed(test_config["input_seed"] + 1) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(test_config["input_seed"] + 1) + negative_input = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device) + + positive_kwargs = {"x": positive_input} + negative_kwargs = {"x": negative_input} + + with torch.no_grad(): + noise_pred = pipeline.predict_noise_maybe_with_cfg( + do_true_cfg=True, + true_cfg_scale=test_config["cfg_scale"], + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=test_config["cfg_normalize"], + ) + + # Sequential CFG always returns output + assert noise_pred is not None + result_queue.put(noise_pred.cpu()) + + destroy_distributed_env() + + +@pytest.mark.parametrize("cfg_parallel_size", [2]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("cfg_normalize", [False, True]) +def test_predict_noise_maybe_with_cfg(cfg_parallel_size: int, dtype: torch.dtype, batch_size: int, cfg_normalize: bool): + """ + Test that predict_noise_maybe_with_cfg produces identical results + with and without CFG parallel. 
+ + Args: + cfg_parallel_size: Number of GPUs for CFG parallel + dtype: Data type for computation + batch_size: Batch size for testing + cfg_normalize: Whether to normalize CFG output + """ + available_gpus = current_omni_platform.get_device_count() + if available_gpus < cfg_parallel_size: + pytest.skip(f"Test requires {cfg_parallel_size} GPUs but only {available_gpus} available") + + test_config = { + "batch_size": batch_size, + "channels": 4, + "height": 16, + "width": 16, + "hidden_dim": 128, + "cfg_scale": 7.5, + "cfg_normalize": cfg_normalize, + "model_seed": 42, # Fixed seed for model initialization + "input_seed": 123, # Fixed seed for input generation + } + + mp_context = torch.multiprocessing.get_context("spawn") + + manager = mp_context.Manager() + baseline_queue = manager.Queue() + cfg_parallel_queue = manager.Queue() + + # Run baseline (sequential CFG) on single GPU + torch.multiprocessing.spawn( + _test_cfg_sequential_worker, + args=(1, dtype, test_config, baseline_queue), + nprocs=1, + ) + + # Run CFG parallel on multiple GPUs + torch.multiprocessing.spawn( + _test_cfg_parallel_worker, + args=(cfg_parallel_size, cfg_parallel_size, dtype, test_config, cfg_parallel_queue), + nprocs=cfg_parallel_size, + ) + + # Get results from queues + baseline_output = baseline_queue.get() + cfg_parallel_output = cfg_parallel_queue.get() + + # Verify shapes match + assert baseline_output.shape == cfg_parallel_output.shape, ( + f"Shape mismatch: baseline {baseline_output.shape} vs CFG parallel {cfg_parallel_output.shape}" + ) + + # Verify numerical equivalence with appropriate tolerances + if dtype == torch.float32: + rtol, atol = 1e-5, 1e-5 + elif dtype == torch.bfloat16: + rtol, atol = 1e-2, 1e-2 + else: + rtol, atol = 1e-3, 1e-3 + + torch.testing.assert_close( + cfg_parallel_output, + baseline_output, + rtol=rtol, + atol=atol, + msg=( + f"CFG parallel output differs from sequential CFG\n" + f" dtype={dtype}, batch_size={batch_size}, cfg_normalize={cfg_normalize}\n" + 
f" Max diff: {(cfg_parallel_output - baseline_output).abs().max().item():.6e}" + ), + ) + + print( + f"✓ Test passed: cfg_size={cfg_parallel_size}, dtype={dtype}, " + f"batch_size={batch_size}, cfg_normalize={cfg_normalize}" + ) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_predict_noise_without_cfg(dtype: torch.dtype): + """ + Test predict_noise_maybe_with_cfg when do_true_cfg=False. + + When CFG is disabled, only the positive branch should be computed. + This test runs on a single GPU without distributed environment. + """ + available_gpus = current_omni_platform.get_device_count() + if available_gpus < 1: + pytest.skip("Test requires at least 1 GPU") + + device = torch.device(f"{current_omni_platform.device_type}:0") + current_omni_platform.set_device(device) + + # Create pipeline without distributed environment + pipeline = TestCFGPipeline(in_channels=4, hidden_dim=128, seed=42) + pipeline.transformer = pipeline.transformer.to(device=device, dtype=dtype) + pipeline.transformer.eval() + + # Set seed for input generation + torch.manual_seed(123) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(123) + positive_input = torch.randn(1, 4, 16, 16, dtype=dtype, device=device) + + with torch.no_grad(): + noise_pred = pipeline.predict_noise_maybe_with_cfg( + do_true_cfg=False, # No CFG + true_cfg_scale=7.5, + positive_kwargs={"x": positive_input}, + negative_kwargs=None, + cfg_normalize=False, + ) + + # Should always return output when do_true_cfg=False + assert noise_pred is not None + assert noise_pred.shape == (1, 4, 16, 16) + + print(f"✓ Test passed: predict_noise without CFG (dtype={dtype})") diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 4248071f01..f884fb6f17 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -55,6 +55,7 @@ def _validate_parallel_config(self) -> Self: assert self.ulysses_degree > 0, "Ulysses degree must be > 0" assert self.ring_degree > 0, "Ring degree must 
be > 0" assert self.cfg_parallel_size > 0, "CFG parallel size must be > 0" + assert self.cfg_parallel_size in [1, 2], f"CFG parallel size must be 1 or 2, but got {self.cfg_parallel_size}" assert self.sequence_parallel_size == self.ulysses_degree * self.ring_degree, ( "Sequence parallel size must be equal to the product of ulysses degree and ring degree," f" but got {self.sequence_parallel_size} != {self.ulysses_degree} * {self.ring_degree}" diff --git a/vllm_omni/diffusion/distributed/cfg_parallel.py b/vllm_omni/diffusion/distributed/cfg_parallel.py new file mode 100644 index 0000000000..9f86bce228 --- /dev/null +++ b/vllm_omni/diffusion/distributed/cfg_parallel.py @@ -0,0 +1,235 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Base pipeline class for Diffusion models with shared CFG functionality. +""" + +from abc import ABCMeta +from typing import Any + +import torch + +from vllm_omni.diffusion.distributed.parallel_state import ( + get_cfg_group, + get_classifier_free_guidance_rank, + get_classifier_free_guidance_world_size, +) + + +class CFGParallelMixin(metaclass=ABCMeta): + """ + Base Mixin class for Diffusion pipelines providing shared CFG methods. + + All pipelines should inherit from this class to reuse + classifier-free guidance logic. + """ + + def predict_noise_maybe_with_cfg( + self, + do_true_cfg: bool, + true_cfg_scale: float, + positive_kwargs: dict[str, Any], + negative_kwargs: dict[str, Any] | None, + cfg_normalize: bool = True, + output_slice: int | None = None, + ) -> torch.Tensor | None: + """ + Predict noise with optional classifier-free guidance. 
+ + Args: + do_true_cfg: Whether to apply CFG + true_cfg_scale: CFG scale factor + positive_kwargs: Kwargs for positive/conditional prediction + negative_kwargs: Kwargs for negative/unconditional prediction + cfg_normalize: Whether to normalize CFG output (default: True) + output_slice: If set, slice output to [:, :output_slice] for image editing + + Returns: + Predicted noise tensor (only valid on rank 0 in CFG parallel mode) + """ + if do_true_cfg: + # Automatically detect CFG parallel configuration + cfg_parallel_ready = get_classifier_free_guidance_world_size() > 1 + + if cfg_parallel_ready: + # Enable CFG-parallel: rank0 computes positive, rank1 computes negative. + cfg_group = get_cfg_group() + cfg_rank = get_classifier_free_guidance_rank() + + if cfg_rank == 0: + local_pred = self.predict_noise(**positive_kwargs) + else: + local_pred = self.predict_noise(**negative_kwargs) + + # Slice output for image editing pipelines (remove condition latents) + if output_slice is not None: + local_pred = local_pred[:, :output_slice] + + gathered = cfg_group.all_gather(local_pred, separate_tensors=True) + + if cfg_rank == 0: + noise_pred = gathered[0] + neg_noise_pred = gathered[1] + noise_pred = self.combine_cfg_noise(noise_pred, neg_noise_pred, true_cfg_scale, cfg_normalize) + return noise_pred + else: + return None + else: + # Sequential CFG: compute both positive and negative + positive_noise_pred = self.predict_noise(**positive_kwargs) + negative_noise_pred = self.predict_noise(**negative_kwargs) + + # Slice output for image editing pipelines + if output_slice is not None: + positive_noise_pred = positive_noise_pred[:, :output_slice] + negative_noise_pred = negative_noise_pred[:, :output_slice] + + noise_pred = self.combine_cfg_noise( + positive_noise_pred, negative_noise_pred, true_cfg_scale, cfg_normalize + ) + return noise_pred + else: + # No CFG: only compute positive/conditional prediction + pred = self.predict_noise(**positive_kwargs) + if output_slice is not 
None: + pred = pred[:, :output_slice] + return pred + + def cfg_normalize_function(self, noise_pred: torch.Tensor, comb_pred: torch.Tensor) -> torch.Tensor: + """ + Normalize the combined noise prediction. + + Args: + noise_pred: positive noise prediction + comb_pred: combined noise prediction after CFG + + Returns: + Normalized noise prediction tensor + """ + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + return noise_pred + + def combine_cfg_noise( + self, noise_pred: torch.Tensor, neg_noise_pred: torch.Tensor, true_cfg_scale: float, cfg_normalize: bool = False + ) -> torch.Tensor: + """ + Combine conditional and unconditional noise predictions with CFG. + + Args: + noise_pred: Conditional noise prediction + neg_noise_pred: Unconditional noise prediction + true_cfg_scale: CFG scale factor + cfg_normalize: Whether to normalize the combined prediction (default: False) + + Returns: + Combined noise prediction tensor + """ + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + if cfg_normalize: + noise_pred = self.cfg_normalize_function(noise_pred, comb_pred) + else: + noise_pred = comb_pred + + return noise_pred + + def predict_noise(self, *args: Any, **kwargs: Any) -> torch.Tensor: + """ + Forward pass through transformer to predict noise. + + Subclasses should override this if they need custom behavior, + but the default implementation calls self.transformer. + """ + return self.transformer(*args, **kwargs)[0] + + def diffuse( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + """ + Diffusion loop with optional classifier-free guidance. + + Subclasses MUST implement this method to define the complete + diffusion/denoising loop for their specific model. 
+ + Typical implementation pattern: + ```python + def diffuse(self, latents, timesteps, prompt_embeds, negative_embeds, ...): + for t in timesteps: + # Prepare kwargs for positive and negative predictions + positive_kwargs = {...} + negative_kwargs = {...} + + # Predict noise with automatic CFG handling + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=True, + true_cfg_scale=self.guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + ) + + # Step scheduler with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg( + noise_pred, t, latents, do_true_cfg=True + ) + + return latents + ``` + """ + raise NotImplementedError("Subclasses must implement diffuse") + + def scheduler_step(self, noise_pred: torch.Tensor, t: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: + """ + Step the scheduler. + + Args: + noise_pred: Predicted noise + t: Current timestep + latents: Current latents + + Returns: + Updated latents after scheduler step + """ + return self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + def scheduler_step_maybe_with_cfg( + self, noise_pred: torch.Tensor, t: torch.Tensor, latents: torch.Tensor, do_true_cfg: bool + ) -> torch.Tensor: + """ + Step the scheduler with (maybe) automatic CFG parallel synchronization. + + In CFG parallel mode, only rank 0 computes the scheduler step, + then broadcasts the result to other ranks. 
+ + Args: + noise_pred: Predicted noise (only valid on rank 0 in CFG parallel) + t: Current timestep + latents: Current latents + do_true_cfg: Whether CFG is enabled + + Returns: + Updated latents (synchronized across all CFG ranks) + """ + # Automatically detect CFG parallel configuration + cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 + + if cfg_parallel_ready: + cfg_group = get_cfg_group() + cfg_rank = get_classifier_free_guidance_rank() + + # Only rank 0 computes the scheduler step + if cfg_rank == 0: + latents = self.scheduler_step(noise_pred, t, latents) + + # Broadcast the updated latents to all ranks + latents = latents.contiguous() + cfg_group.broadcast(latents, src=0) + else: + # No CFG parallel: directly compute scheduler step + latents = self.scheduler_step(noise_pred, t, latents) + + return latents diff --git a/vllm_omni/diffusion/models/flux/pipeline_flux.py b/vllm_omni/diffusion/models/flux/pipeline_flux.py index a91601b7ec..b90aaa8ca4 100644 --- a/vllm_omni/diffusion/models/flux/pipeline_flux.py +++ b/vllm_omni/diffusion/models/flux/pipeline_flux.py @@ -22,6 +22,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.parallel_state import get_classifier_free_guidance_world_size from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.flux import FluxTransformer2DModel @@ -554,6 +555,21 @@ def diffuse( latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] return latents + def check_cfg_parallel_validity(self, true_cfg_scale: float, has_neg_prompt: bool): + if get_classifier_free_guidance_world_size() == 1: + return True + + if true_cfg_scale <= 1: + logger.warning("CFG parallel is NOT working correctly when true_cfg_scale <= 1.") + return False 
+ + if not has_neg_prompt: + logger.warning( + "CFG parallel is NOT working correctly when there is no negative prompt or negative prompt embeddings." + ) + return False + return True + def forward( self, req: OmniDiffusionRequest, @@ -637,6 +653,8 @@ def forward( ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt) + ( prompt_embeds, pooled_prompt_embeds, diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index 2eac7b2ece..e1ef706c3f 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -36,6 +36,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import ( @@ -178,7 +179,7 @@ def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: return float(mu) -class Flux2KleinPipeline(nn.Module, SupportImageInput): +class Flux2KleinPipeline(nn.Module, CFGParallelMixin, SupportImageInput): """Flux2 klein pipeline for text-to-image generation.""" support_image_input = True @@ -923,38 +924,44 @@ def forward( latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=None, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - 
- noise_pred = noise_pred[:, : latents.size(1) :] - + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "guidance": None, + "encoder_hidden_states": prompt_embeds, + "txt_ids": text_ids, + "img_ids": latent_image_ids, + "joint_attention_kwargs": self.attention_kwargs, + "return_dict": False, + } if self.do_classifier_free_guidance: - neg_noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=None, - encoder_hidden_states=negative_prompt_embeds, - txt_ids=negative_text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - neg_noise_pred = neg_noise_pred[:, : latents.size(1) :] - noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) - - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype and torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "guidance": None, + "encoder_hidden_states": negative_prompt_embeds, + "txt_ids": negative_text_ids, + "img_ids": latent_image_ids, + "joint_attention_kwargs": self.attention_kwargs, + "return_dict": False, + } + else: + negative_kwargs = None + + # For editing pipelines, we need to slice the output to remove condition latents + output_slice = latents.size(1) if image_latents is not None else None + + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=self.do_classifier_free_guidance, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + output_slice=output_slice, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, self.do_classifier_free_guidance) if callback_on_step_end 
is not None: callback_kwargs = {} diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py index 8b616ec45f..09f409f313 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py @@ -8,6 +8,7 @@ import os import re from collections.abc import Iterable +from functools import partial from typing import Any import numpy as np @@ -23,6 +24,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.longcat_image.longcat_image_transformer import LongCatImageTransformer2DModel @@ -197,9 +199,7 @@ def get_prompt_language(prompt): return "en" -class LongCatImagePipeline( - nn.Module, -): +class LongCatImagePipeline(nn.Module, CFGParallelMixin): def __init__( self, *, @@ -460,6 +460,16 @@ def check_inputs( f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) + def cfg_normalize_function(self, noise_pred, comb_pred, cfg_renorm_min=0.0): + """ + Normalize the combined noise prediction. 
+ """ + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + scale = (cond_norm / (noise_norm + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) + noise_pred = comb_pred * scale + return noise_pred + def forward( self, req: OmniDiffusionRequest, @@ -604,6 +614,9 @@ def forward( if self.do_classifier_free_guidance: negative_prompt_embeds = negative_prompt_embeds.to(device) + # custom partial function with cfg_renorm_min + self.cfg_normalize_function = partial(self.cfg_normalize_function, cfg_renorm_min=cfg_renorm_min) + # 6. Denoising loop for i, t in enumerate(timesteps): if self._interrupt: @@ -612,43 +625,37 @@ def forward( self._current_timestep = t timestep = t.expand(latents.shape[0]).to(latents.dtype) - noise_pred_text = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - return_dict=False, - )[0] - + positive_kwargs = { + "hidden_states": latents, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states": prompt_embeds, + "txt_ids": text_ids, + "img_ids": latent_image_ids, + "return_dict": False, + } if self.do_classifier_free_guidance: - noise_pred_uncond = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - encoder_hidden_states=negative_prompt_embeds, - txt_ids=negative_text_ids, - img_ids=latent_image_ids, - return_dict=False, - )[0] - noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond) - - if enable_cfg_renorm: - cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True) - noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - scale = (cond_norm / (noise_norm + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) - noise_pred = noise_pred * scale + negative_kwargs = { + "hidden_states": latents, + "timestep": timestep / 1000, + "encoder_hidden_states": negative_prompt_embeds, + "txt_ids": 
negative_text_ids, + "img_ids": latent_image_ids, + "return_dict": False, + } else: - noise_pred = noise_pred_text - - # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + negative_kwargs = None + + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=self.do_classifier_free_guidance, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=enable_cfg_renorm, + ) - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, self.do_classifier_free_guidance) self._current_timestep = None diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py index f2c3fd648e..a34a2cca39 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py @@ -26,6 +26,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput @@ -215,7 +216,7 @@ def split_quotation(prompt, quote_pairs=None): return result -class LongCatImageEditPipeline(nn.Module, SupportImageInput): +class LongCatImageEditPipeline(nn.Module, CFGParallelMixin, 
SupportImageInput): def __init__( self, *, @@ -652,38 +653,39 @@ def forward( latent_model_input = torch.cat([latents, image_latents], dim=1) timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) - - noise_pred_text = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - return_dict=False, - )[0] - noise_pred_text = noise_pred_text[:, :image_seq_len] - if guidance_scale > 1: - noise_pred_uncond = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - encoder_hidden_states=negative_prompt_embeds, - txt_ids=negative_text_ids, - img_ids=latent_image_ids, - return_dict=False, - )[0] - noise_pred_uncond = noise_pred_uncond[:, :image_seq_len] - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + do_true_cfg = guidance_scale > 1 + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states": prompt_embeds, + "txt_ids": text_ids, + "img_ids": latent_image_ids, + "return_dict": False, + } + + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "encoder_hidden_states": negative_prompt_embeds, + "txt_ids": negative_text_ids, + "img_ids": latent_image_ids, + "return_dict": False, + } else: - noise_pred = noise_pred_text - # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) + negative_kwargs = None + + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + output_slice=image_seq_len, + ) + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) self._current_timestep = None diff --git a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py index a87410a6c9..963f1c483b 100644 --- a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py +++ b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py @@ -23,18 +23,19 @@ import numpy as np import torch -import torch.nn as nn from diffusers.image_processor import VaeImageProcessor from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.torch_utils import randn_tensor +from torch import nn from transformers import Qwen2TokenizerFast, Qwen3Model from vllm.logger import init_logger from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.ovis_image.ovis_image_transformer import OvisImageTransformer2DModel @@ -139,9 +140,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class OvisImagePipeline( - nn.Module, -): +class OvisImagePipeline(nn.Module, 
CFGParallelMixin): def __init__( self, *, @@ -431,68 +430,77 @@ def prepare_timesteps(self, num_inference_steps, sigmas, image_seq_len): ) return timesteps, num_inference_steps - def denoising( + def diffuse( self, - latents, - timesteps, - prompt_embeds, - negative_prompt_embeds, - guidance_scale, - do_classifier_free_guidance, - callback_on_step_end, - callback_on_step_end_tensor_inputs, - text_ids, - negative_text_ids, - latent_image_ids, - ): + latents: torch.Tensor, + timesteps: torch.Tensor, + prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + text_ids: torch.Tensor, + negative_text_ids: torch.Tensor, + latent_image_ids: torch.Tensor, + do_true_cfg: bool, + guidance_scale: float, + cfg_normalize: bool = False, + ) -> torch.Tensor: + """ + Diffusion loop with optional classifier-free guidance. + + Args: + latents: Noise latents to denoise + timesteps: Diffusion timesteps + prompt_embeds: Positive prompt embeddings + negative_prompt_embeds: Negative prompt embeddings + text_ids: Position IDs for positive text + negative_text_ids: Position IDs for negative text + latent_image_ids: Position IDs for image latents + do_true_cfg: Whether to apply CFG + guidance_scale: CFG scale factor + cfg_normalize: Whether to normalize CFG output (default: False) + + Returns: + Denoised latents + """ self.scheduler.set_begin_index(0) + for i, t in enumerate(timesteps): - if self._interrupt: + if self.interrupt: break self._current_timestep = t - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - return_dict=False, - )[0] - - if do_classifier_free_guidance: - neg_noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - encoder_hidden_states=negative_prompt_embeds, - 
txt_ids=negative_text_ids, - img_ids=latent_image_ids, - return_dict=False, - )[0] - - noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) - - # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) - - # Not used in this pipeline - # if callback_on_step_end is not None: - # callback_kwargs = {} - # for k in callback_on_step_end_tensor_inputs: - # callback_kwargs[k] = locals()[k] - # callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - # latents = callback_outputs.pop("latents", latents) - # prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + positive_kwargs = { + "hidden_states": latents, + "timestep": timestep / 1000, + "encoder_hidden_states": prompt_embeds, + "txt_ids": text_ids, + "img_ids": latent_image_ids, + "return_dict": False, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latents, + "timestep": timestep / 1000, + "encoder_hidden_states": negative_prompt_embeds, + "txt_ids": negative_text_ids, + "img_ids": latent_image_ids, + "return_dict": False, + } + else: + negative_kwargs = None + + # Predict noise with automatic CFG parallel handling + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg, + guidance_scale, + positive_kwargs, + negative_kwargs, + cfg_normalize, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) return latents @@ -704,23 +712,18 @@ def forward( if self.joint_attention_kwargs is None: self._joint_attention_kwargs = {} - # 6. 
Denoising loop - - # We set the index here to remove DtoH sync, helpful especially during compilation - # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 - - latents = self.denoising( - latents, - timesteps, - prompt_embeds, - negative_prompt_embeds, - guidance_scale, - do_classifier_free_guidance, - callback_on_step_end, - callback_on_step_end_tensor_inputs, - text_ids, - negative_text_ids, - latent_image_ids, + # 6. Denoising loop using diffuse method + latents = self.diffuse( + latents=latents, + timesteps=timesteps, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds if do_classifier_free_guidance else None, + text_ids=text_ids, + negative_text_ids=negative_text_ids if do_classifier_free_guidance else None, + latent_image_ids=latent_image_ids, + do_true_cfg=do_classifier_free_guidance, + guidance_scale=guidance_scale, + cfg_normalize=False, ) self._current_timestep = None diff --git a/vllm_omni/diffusion/models/qwen_image/__init__.py b/vllm_omni/diffusion/models/qwen_image/__init__.py index 84fa2259d4..4b823ec75d 100644 --- a/vllm_omni/diffusion/models/qwen_image/__init__.py +++ b/vllm_omni/diffusion/models/qwen_image/__init__.py @@ -2,6 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Qwen Image diffusion model components.""" +from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( + QwenImageCFGParallelMixin, +) from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import ( QwenImagePipeline, get_qwen_image_post_process_func, @@ -11,6 +14,7 @@ ) __all__ = [ + "QwenImageCFGParallelMixin", "QwenImagePipeline", "QwenImageTransformer2DModel", "get_qwen_image_post_process_func", diff --git a/vllm_omni/diffusion/models/qwen_image/cfg_parallel.py b/vllm_omni/diffusion/models/qwen_image/cfg_parallel.py new file mode 100644 index 0000000000..9a882f7bf0 --- /dev/null +++ b/vllm_omni/diffusion/models/qwen_image/cfg_parallel.py @@ -0,0 +1,160 @@ +# 
SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""CFG Parallel Mixin for Qwen Image series +Shared by +- QwenImagePipeline +- QwenImageEditPipeline +- QwenImageEditPlusPipeline +- QwenImageLayeredPipeline +""" + +import logging +from typing import Any + +import torch + +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.parallel_state import get_classifier_free_guidance_world_size + +logger = logging.getLogger(__name__) + + +class QwenImageCFGParallelMixin(CFGParallelMixin): + """ + Base Mixin class for Qwen Image pipelines providing shared CFG methods. + """ + + def diffuse( + self, + prompt_embeds: torch.Tensor, + prompt_embeds_mask: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + negative_prompt_embeds_mask: torch.Tensor, + latents: torch.Tensor, + img_shapes: torch.Tensor, + txt_seq_lens: torch.Tensor, + negative_txt_seq_lens: torch.Tensor, + timesteps: torch.Tensor, + do_true_cfg: bool, + guidance: torch.Tensor, + true_cfg_scale: float, + image_latents: torch.Tensor | None = None, + cfg_normalize: bool = True, + additional_transformer_kwargs: dict[str, Any] | None = None, + ) -> torch.Tensor: + """ + Diffusion loop with optional classifier-free guidance. 
+ + Args: + prompt_embeds: Positive prompt embeddings + prompt_embeds_mask: Mask for positive prompt + negative_prompt_embeds: Negative prompt embeddings + negative_prompt_embeds_mask: Mask for negative prompt + latents: Noise latents to denoise + img_shapes: Image shape information + txt_seq_lens: Text sequence lengths for positive prompts + negative_txt_seq_lens: Text sequence lengths for negative prompts + timesteps: Diffusion timesteps + do_true_cfg: Whether to apply CFG + guidance: Guidance scale tensor + true_cfg_scale: CFG scale factor + image_latents: Conditional image latents for editing (default: None) + cfg_normalize: Whether to normalize CFG output (default: True) + additional_transformer_kwargs: Extra kwargs to pass to transformer (default: None) + + Returns: + Denoised latents + """ + self.scheduler.set_begin_index(0) + self.transformer.do_true_cfg = do_true_cfg + additional_transformer_kwargs = additional_transformer_kwargs or {} + + for i, t in enumerate(timesteps): + if self.interrupt: + continue + self._current_timestep = t + + # Broadcast timestep to match batch size + timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) + + # Concatenate image latents with noise latents if available (for editing pipelines) + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states_mask": prompt_embeds_mask, + "encoder_hidden_states": prompt_embeds, + "img_shapes": img_shapes, + "txt_seq_lens": txt_seq_lens, + **additional_transformer_kwargs, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep / 1000, + "guidance": guidance, + "encoder_hidden_states_mask": negative_prompt_embeds_mask, + "encoder_hidden_states": negative_prompt_embeds, + "img_shapes": img_shapes, + 
"txt_seq_lens": negative_txt_seq_lens, + **additional_transformer_kwargs, + } + else: + negative_kwargs = None + + # For editing pipelines, we need to slice the output to remove condition latents + output_slice = latents.size(1) if image_latents is not None else None + + # Predict noise with automatic CFG parallel handling + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg, + true_cfg_scale, + positive_kwargs, + negative_kwargs, + cfg_normalize, + output_slice, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + + return latents + + def check_cfg_parallel_validity(self, true_cfg_scale: float, has_neg_prompt: bool): + """ + Validate whether CFG parallel is properly configured for the current generation request. + + When CFG parallel is enabled (cfg_parallel_world_size > 1), this method verifies that the necessary + conditions are met for correct parallel execution. If validation fails, a warning is + logged to help identify configuration issues. + + Args: + true_cfg_scale: The classifier-free guidance scale value. Must be > 1 for CFG to + have an effect. + has_neg_prompt: Whether negative prompts or negative prompt embeddings are provided. + Required for CFG to perform unconditional prediction. + + Returns: + True if CFG parallel is disabled or all validation checks pass, False otherwise. + + Note: + When CFG parallel is disabled (world_size == 1), this method always returns True + as no parallel-specific validation is needed. + """ + if get_classifier_free_guidance_world_size() == 1: + return True + + if true_cfg_scale <= 1: + logger.warning("CFG parallel is NOT working correctly when true_cfg_scale <= 1.") + return False + + if not has_neg_prompt: + logger.warning( + "CFG parallel is NOT working correctly when there is no negative prompt or negative prompt embeddings." 
+ ) + return False + return True diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index e0a37b8bc8..d85d98b5bf 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -25,13 +25,11 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.parallel_state import ( - get_cfg_group, - get_classifier_free_guidance_rank, - get_classifier_free_guidance_world_size, -) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( + QwenImageCFGParallelMixin, +) from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, ) @@ -239,9 +237,7 @@ def apply_rotary_emb_qwen( return x_out.type_as(x) -class QwenImagePipeline( - nn.Module, -): +class QwenImagePipeline(nn.Module, QwenImageCFGParallelMixin): def __init__( self, *, @@ -536,109 +532,6 @@ def current_timestep(self): def interrupt(self): return self._interrupt - def diffuse( - self, - prompt_embeds, - prompt_embeds_mask, - negative_prompt_embeds, - negative_prompt_embeds_mask, - latents, - img_shapes, - txt_seq_lens, - negative_txt_seq_lens, - timesteps, - do_true_cfg, - guidance, - true_cfg_scale, - ): - self.scheduler.set_begin_index(0) - for i, t in enumerate(timesteps): - if self.interrupt: - continue - self._current_timestep = t - - # Broadcast timestep to match batch size - timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) - - # Enable CFG-parallel: rank0 computes positive, rank1 computes negative. 
- cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 - - self.transformer.do_true_cfg = do_true_cfg - - if cfg_parallel_ready: - cfg_group = get_cfg_group() - cfg_rank = get_classifier_free_guidance_rank() - - if cfg_rank == 0: - local_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - else: - local_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - - gathered = cfg_group.all_gather(local_pred, separate_tensors=True) - if cfg_rank == 0: - noise_pred = gathered[0] - neg_noise_pred = gathered[1] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - cfg_group.broadcast(latents, src=0) - - else: - # Forward pass for positive prompt (or unconditional if no CFG) - noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - # Forward pass for negative prompt (CFG) - if do_true_cfg: - neg_noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - 
guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - return latents - def forward( self, req: OmniDiffusionRequest, @@ -720,6 +613,8 @@ def forward( ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt) + prompt_embeds, prompt_embeds_mask = self.encode_prompt( prompt=prompt, prompt_embeds=prompt_embeds, @@ -790,6 +685,12 @@ def forward( do_true_cfg, guidance, true_cfg_scale, + image_latents=None, + cfg_normalize=True, + additional_transformer_kwargs={ + "return_dict": False, + "attention_kwargs": self.attention_kwargs, + }, ) self._current_timestep = None diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index 4241370bab..78fd92c9d5 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -25,14 +25,12 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.parallel_state import ( - get_cfg_group, - get_classifier_free_guidance_rank, - get_classifier_free_guidance_world_size, -) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import 
DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( + QwenImageCFGParallelMixin, +) from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import calculate_shift from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, @@ -213,10 +211,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class QwenImageEditPipeline( - nn.Module, - SupportImageInput, -): +class QwenImageEditPipeline(nn.Module, SupportImageInput, QwenImageCFGParallelMixin): def __init__( self, *, @@ -598,118 +593,6 @@ def current_timestep(self): def interrupt(self): return self._interrupt - def diffuse( - self, - prompt_embeds, - prompt_embeds_mask, - negative_prompt_embeds, - negative_prompt_embeds_mask, - latents, - image_latents, - img_shapes, - txt_seq_lens, - negative_txt_seq_lens, - timesteps, - do_true_cfg, - guidance, - true_cfg_scale, - ): - """Diffusion loop with optional image conditioning.""" - self.scheduler.set_begin_index(0) - for i, t in enumerate(timesteps): - if self.interrupt: - continue - self._current_timestep = t - # broadcast to batch dimension and place on same device/dtype as latents - timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) - - # Concatenate image latents with noise latents if available - latent_model_input = latents - if image_latents is not None: - latent_model_input = torch.cat([latents, image_latents], dim=1) - - self.transformer.do_true_cfg = do_true_cfg # used in teacache hook - cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 - - if cfg_parallel_ready: - cfg_group = get_cfg_group() - cfg_rank = get_classifier_free_guidance_rank() - - if cfg_rank == 0: - local_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - 
encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - else: - local_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - - local_pred = local_pred[:, : latents.size(1)] - - gathered = cfg_group.all_gather(local_pred, separate_tensors=True) - if cfg_rank == 0: - noise_pred = gathered[0] - neg_noise_pred = gathered[1] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - cfg_group.broadcast(latents, src=0) - - else: - # Forward pass for positive prompt (or unconditional if no CFG) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_pred[:, : latents.size(1)] - - # Forward pass for negative prompt (CFG) - if do_true_cfg: - neg_noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - 
return_dict=False, - )[0] - neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - return latents - def forward( self, req: OmniDiffusionRequest, @@ -823,6 +706,8 @@ def forward( ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt) + prompt_embeds, prompt_embeds_mask = self.encode_prompt( prompt=prompt, image=prompt_image, # Use resized image for prompt encoding @@ -886,7 +771,6 @@ def forward( negative_prompt_embeds, negative_prompt_embeds_mask, latents, - image_latents, img_shapes, txt_seq_lens, negative_txt_seq_lens, @@ -894,6 +778,12 @@ def forward( do_true_cfg, guidance, true_cfg_scale, + image_latents=image_latents, + cfg_normalize=True, + additional_transformer_kwargs={ + "return_dict": False, + "attention_kwargs": self.attention_kwargs, + }, ) self._current_timestep = None diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 9402c1f7ce..00e7758029 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -23,14 +23,12 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.parallel_state import ( - get_cfg_group, - get_classifier_free_guidance_rank, - get_classifier_free_guidance_world_size, -) from vllm_omni.diffusion.distributed.utils import get_local_device from 
vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( + QwenImageCFGParallelMixin, +) from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import calculate_shift from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit import ( calculate_dimensions, @@ -168,7 +166,7 @@ def post_process_func( return post_process_func -class QwenImageEditPlusPipeline(nn.Module, SupportImageInput): +class QwenImageEditPlusPipeline(nn.Module, SupportImageInput, QwenImageCFGParallelMixin): def __init__( self, *, @@ -530,116 +528,6 @@ def current_timestep(self): def interrupt(self): return self._interrupt - def diffuse( - self, - prompt_embeds, - prompt_embeds_mask, - negative_prompt_embeds, - negative_prompt_embeds_mask, - latents, - image_latents, - img_shapes, - txt_seq_lens, - negative_txt_seq_lens, - timesteps, - do_true_cfg, - guidance, - true_cfg_scale, - ): - """Diffusion loop with optional image conditioning.""" - self.scheduler.set_begin_index(0) - for i, t in enumerate(timesteps): - if self.interrupt: - continue - self._current_timestep = t - # broadcast to batch dimension and place on same device/dtype as latents - timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) - - # Concatenate image latents with noise latents if available - latent_model_input = latents - if image_latents is not None: - latent_model_input = torch.cat([latents, image_latents], dim=1) - - self.transformer.do_true_cfg = do_true_cfg - cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 - - if cfg_parallel_ready: - cfg_group = get_cfg_group() - cfg_rank = get_classifier_free_guidance_rank() - - if cfg_rank == 0: - local_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - 
encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - else: - local_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - - local_pred = local_pred[:, : latents.size(1)] - - gathered = cfg_group.all_gather(local_pred, separate_tensors=True) - if cfg_rank == 0: - noise_pred = gathered[0] - neg_noise_pred = gathered[1] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - cfg_group.broadcast(latents, src=0) - - else: - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_pred[:, : latents.size(1)] - - if do_true_cfg: - neg_noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - comb_pred = neg_noise_pred + 
true_cfg_scale * (noise_pred - neg_noise_pred) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - return latents - def forward( self, req: OmniDiffusionRequest, @@ -773,6 +661,8 @@ def forward( ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt) + prompt_embeds, prompt_embeds_mask = self.encode_prompt( prompt=prompt, image=condition_images, # Use condition images for prompt encoding @@ -840,7 +730,6 @@ def forward( negative_prompt_embeds, negative_prompt_embeds_mask, latents, - image_latents, img_shapes, txt_seq_lens, negative_txt_seq_lens, @@ -848,6 +737,12 @@ def forward( do_true_cfg, guidance, true_cfg_scale, + image_latents=image_latents, + cfg_normalize=True, + additional_transformer_kwargs={ + "return_dict": False, + "attention_kwargs": self.attention_kwargs, + }, ) self._current_timestep = None diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index 1001e8a140..d200764ebf 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -22,17 +22,15 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.parallel_state import ( - get_cfg_group, - get_classifier_free_guidance_rank, - get_classifier_free_guidance_world_size, -) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import 
SupportImageInput from vllm_omni.diffusion.models.qwen_image.autoencoder_kl_qwenimage import ( AutoencoderKLQwenImage, ) +from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( + QwenImageCFGParallelMixin, +) from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, ) @@ -191,7 +189,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class QwenImageLayeredPipeline(nn.Module, SupportImageInput): +class QwenImageLayeredPipeline(nn.Module, SupportImageInput, QwenImageCFGParallelMixin): color_format = "RGBA" def __init__( @@ -553,138 +551,6 @@ def _unpack_latents(latents, height, width, layers, vae_scale_factor): return latents - def diffuse( - self, - prompt_embeds, - prompt_embeds_mask, - negative_prompt_embeds, - negative_prompt_embeds_mask, - latents, - image_latents, - img_shapes, - txt_seq_lens, - negative_txt_seq_lens, - timesteps, - do_true_cfg, - guidance, - true_cfg_scale, - cfg_normalize, - is_rgb, - ): - self.scheduler.set_begin_index(0) - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - self._current_timestep = t - - latent_model_input = latents - if image_latents is not None: - latent_model_input = torch.cat([latents, image_latents], dim=1) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) - - # Enable CFG-parallel: rank0 computes positive, rank1 computes negative. 
- cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 - - self.transformer.do_true_cfg = do_true_cfg # used in teacache hook - - if cfg_parallel_ready: - cfg_group = get_cfg_group() - cfg_rank = get_classifier_free_guidance_rank() - - if cfg_rank == 0: - local_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - attention_kwargs=self.attention_kwargs, - additional_t_cond=is_rgb, - return_dict=False, - )[0] - local_pred = local_pred[:, : latents.size(1)] - else: - local_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - additional_t_cond=is_rgb, - return_dict=False, - )[0] - local_pred = local_pred[:, : latents.size(1)] - - gathered = cfg_group.all_gather(local_pred, separate_tensors=True) - if cfg_rank == 0: - noise_pred = gathered[0] - neg_noise_pred = gathered[1] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - - if cfg_normalize: - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - else: - noise_pred = comb_pred - - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - cfg_group.broadcast(latents, src=0) - - else: - # Forward pass for positive prompt (or unconditional if no CFG) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, - 
img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, - attention_kwargs=self.attention_kwargs, - additional_t_cond=is_rgb, - return_dict=False, - )[0] - noise_pred = noise_pred[:, : latents.size(1)] - - if do_true_cfg: - neg_noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, - attention_kwargs=self.attention_kwargs, - additional_t_cond=is_rgb, - return_dict=False, - )[0] - - neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - - if cfg_normalize: - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - else: - noise_pred = comb_pred - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - # no need callback now - return latents - @property def guidance_scale(self): return self._guidance_scale @@ -840,6 +706,8 @@ def forward( ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt) + prompt_embeds, prompt_embeds_mask = self.encode_prompt( prompt=prompt, prompt_embeds=prompt_embeds, @@ -924,7 +792,6 @@ def forward( negative_prompt_embeds, negative_prompt_embeds_mask, latents, - image_latents, img_shapes, txt_seq_lens, negative_txt_seq_lens, @@ -932,8 +799,13 @@ def forward( do_true_cfg, guidance, true_cfg_scale, - cfg_normalize, - is_rgb, + image_latents=image_latents, + cfg_normalize=cfg_normalize, + additional_transformer_kwargs={ + "return_dict": False, + "additional_t_cond": is_rgb, + "attention_kwargs": self.attention_kwargs, + }, ) self._current_timestep = None diff --git 
a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py index 34a0eb6c14..3668c132f5 100644 --- a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py +++ b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py @@ -16,6 +16,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.sd3.sd3_transformer import ( @@ -126,9 +127,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class StableDiffusion3Pipeline( - nn.Module, -): +class StableDiffusion3Pipeline(nn.Module, CFGParallelMixin): def __init__( self, *, @@ -498,15 +497,35 @@ def interrupt(self): def diffuse( self, - prompt_embeds, - pooled_prompt_embeds, - negative_prompt_embeds, - negative_pooled_prompt_embeds, - latents, - timesteps, - do_cfg, - ): + latents: torch.Tensor, + timesteps: torch.Tensor, + prompt_embeds: torch.Tensor, + pooled_prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + negative_pooled_prompt_embeds: torch.Tensor, + do_true_cfg: bool, + guidance_scale: float, + cfg_normalize: bool = False, + ) -> torch.Tensor: + """ + Diffusion loop with optional classifier-free guidance. 
+ + Args: + latents: Noise latents to denoise + timesteps: Diffusion timesteps + prompt_embeds: Positive prompt embeddings + pooled_prompt_embeds: Pooled positive prompt embeddings + negative_prompt_embeds: Negative prompt embeddings + negative_pooled_prompt_embeds: Pooled negative prompt embeddings + do_true_cfg: Whether to apply CFG + guidance_scale: CFG scale factor + cfg_normalize: Whether to normalize CFG output (default: False) + + Returns: + Denoised latents + """ self.scheduler.set_begin_index(0) + for _, t in enumerate(timesteps): if self.interrupt: continue @@ -515,29 +534,36 @@ def diffuse( # Broadcast timestep to match batch size timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) - transformer_kwargs = { + positive_kwargs = { "hidden_states": latents, "timestep": timestep, "encoder_hidden_states": prompt_embeds, "pooled_projections": pooled_prompt_embeds, "return_dict": False, } - - noise_pred = self.transformer(**transformer_kwargs)[0] - - if do_cfg: - neg_transformer_kwargs = { + if do_true_cfg: + negative_kwargs = { "hidden_states": latents, "timestep": timestep, "encoder_hidden_states": negative_prompt_embeds, "pooled_projections": negative_pooled_prompt_embeds, "return_dict": False, } + else: + negative_kwargs = None + + # Predict noise with automatic CFG parallel handling + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg, + guidance_scale, + positive_kwargs, + negative_kwargs, + cfg_normalize, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) - neg_noise_pred = self.transformer(**neg_transformer_kwargs)[0] - noise_pred = neg_noise_pred + self.guidance_scale * (noise_pred - neg_noise_pred) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] return latents def forward( @@ -644,14 +670,17 @@ def forward( timesteps, 
num_inference_steps = self.prepare_timesteps(num_inference_steps, sigmas, latents.shape[1]) self._num_timesteps = len(timesteps) + # Denoising loop using diffuse method latents = self.diffuse( - prompt_embeds, - pooled_prompt_embeds, - negative_prompt_embeds, - negative_pooled_prompt_embeds, - latents, - timesteps, - do_cfg, + latents=latents, + timesteps=timesteps, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds if do_cfg else None, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds if do_cfg else None, + do_true_cfg=do_cfg, + guidance_scale=self.guidance_scale, + cfg_normalize=False, ) self._current_timestep = None diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 562d8eec51..b902bc692e 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -7,7 +7,7 @@ import logging import os from collections.abc import Iterable -from typing import cast +from typing import Any, cast import PIL.Image import torch @@ -18,12 +18,14 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler from vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import WanTransformer3DModel from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.platforms import current_omni_platform logger = logging.getLogger(__name__) @@ -184,7 +186,7 @@ def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: 
return pre_process_func -class Wan22Pipeline(nn.Module): +class Wan22Pipeline(nn.Module, CFGParallelMixin): def __init__( self, *, @@ -585,26 +587,44 @@ def forward( latent_model_input = latents.to(dtype) timestep = t.expand(latents.shape[0]) - noise_pred = current_model( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if current_guidance_scale > 1.0 and negative_prompt_embeds is not None: - noise_uncond = current_model( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) - - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None + # Prepare kwargs for positive and negative predictions + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": prompt_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": current_model, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": negative_prompt_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": current_model, + } + else: + negative_kwargs = None + + # Predict noise with automatic CFG parallel handling + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=current_guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + # Wan2.2 is prone to out of 
memory errors when predicting large videos + # so we empty the cache here to avoid OOM before vae decoding. + if current_omni_platform.is_available(): + current_omni_platform.empty_cache() self._current_timestep = None # For I2V mode: blend final latents with condition @@ -629,6 +649,21 @@ def forward( return DiffusionOutput(output=output) + def predict_noise(self, current_model: nn.Module | None = None, **kwargs: Any) -> torch.Tensor: + """ + Forward pass through transformer to predict noise. + + Args: + current_model: The transformer model to use (transformer or transformer_2) + **kwargs: Arguments to pass to the transformer + + Returns: + Predicted noise tensor + """ + if current_model is None: + current_model = self.transformer + return current_model(**kwargs)[0] + def encode_prompt( self, prompt: str | list[str], diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index ea045e5d67..1aed9b75de 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -6,7 +6,7 @@ import logging import os from collections.abc import Iterable -from typing import cast +from typing import Any, cast import numpy as np import PIL.Image @@ -18,6 +18,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput @@ -29,6 +30,7 @@ ) from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.platforms import current_omni_platform logger = logging.getLogger(__name__) @@ -136,7 +138,7 @@ def 
pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: return pre_process_func -class Wan22I2VPipeline(nn.Module, SupportImageInput): +class Wan22I2VPipeline(nn.Module, SupportImageInput, CFGParallelMixin): """ Wan2.2 Image-to-Video Pipeline. @@ -484,31 +486,46 @@ def forward( latent_model_input = torch.cat([latents, condition], dim=1).to(dtype) timestep = t.expand(latents.shape[0]) - # Forward pass - noise_pred = current_model( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - encoder_hidden_states_image=image_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - # Classifier-free guidance - if current_guidance_scale > 1.0 and negative_prompt_embeds is not None: - noise_uncond = current_model( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - encoder_hidden_states_image=image_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) - - # Scheduler step - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None + # Prepare kwargs for positive and negative predictions + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": prompt_embeds, + "encoder_hidden_states_image": image_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": current_model, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": negative_prompt_embeds, + "encoder_hidden_states_image": image_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": current_model, + } + else: + negative_kwargs = None + + # Predict noise with automatic CFG parallel handling 
+ noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=current_guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + # Wan2.2 is prone to out of memory errors when predicting large videos + # so we empty the cache here to avoid OOM before vae decoding. + if current_omni_platform.is_available(): + current_omni_platform.empty_cache() self._current_timestep = None # For expand_timesteps mode, blend final latents with condition @@ -533,6 +550,21 @@ def forward( return DiffusionOutput(output=output) + def predict_noise(self, current_model: nn.Module | None = None, **kwargs: Any) -> torch.Tensor: + """ + Forward pass through transformer to predict noise. + + Args: + current_model: The transformer model to use (transformer or transformer_2) + **kwargs: Arguments to pass to the transformer + + Returns: + Predicted noise tensor + """ + if current_model is None: + current_model = self.transformer + return current_model(**kwargs)[0] + def encode_prompt( self, prompt: str | list[str], diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index 5944b678a0..d32b7d697c 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -19,7 +19,7 @@ import logging import os from collections.abc import Iterable -from typing import cast +from typing import Any, cast import numpy as np import PIL.Image @@ -31,6 +31,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import 
get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput @@ -42,6 +43,7 @@ ) from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.inputs.data import OmniTextPrompt +from vllm_omni.platforms import current_omni_platform logger = logging.getLogger(__name__) @@ -126,7 +128,7 @@ def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: return pre_process_func -class Wan22TI2VPipeline(nn.Module, SupportImageInput): +class Wan22TI2VPipeline(nn.Module, SupportImageInput, CFGParallelMixin): """ Wan2.2 Text-Image-to-Video (TI2V) Pipeline. @@ -399,29 +401,44 @@ def forward( temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) - # Forward pass - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - # Classifier-free guidance - if guidance_scale > 1.0 and negative_prompt_embeds is not None: - noise_uncond = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - - # Scheduler step - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None + # Prepare kwargs for positive and negative predictions + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": prompt_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": self.transformer, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + 
"encoder_hidden_states": negative_prompt_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": self.transformer, + } + else: + negative_kwargs = None + + # Predict noise with automatic CFG parallel handling + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + ) + + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + # Wan2.2 is prone to out of memory errors when predicting large videos + # so we empty the cache here to avoid OOM before vae decoding. + if current_omni_platform.is_available(): + current_omni_platform.empty_cache() self._current_timestep = None # For I2V mode, blend final latents with condition @@ -446,6 +463,21 @@ def forward( return DiffusionOutput(output=output) + def predict_noise(self, current_model: nn.Module | None = None, **kwargs: Any) -> torch.Tensor: + """ + Forward pass through transformer to predict noise. 
+ + Args: + current_model: The transformer model to use + **kwargs: Arguments to pass to the transformer + + Returns: + Predicted noise tensor + """ + if current_model is None: + current_model = self.transformer + return current_model(**kwargs)[0] + def encode_prompt( self, prompt: str | list[str], diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index 5613bdaeb5..083114ee40 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -225,7 +225,12 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu help="Scheduler flow_shift for video models (e.g., 5.0 for 720p, 12.0 for 480p).", ) omni_config_group.add_argument( - "--cfg-parallel-size", type=int, default=1, help="Number of GPUs for CFG parallel computation" + "--cfg-parallel-size", + type=int, + default=1, + choices=[1, 2], + help="Number of devices for CFG parallel computation for diffusion models. " + "Equivalent to setting DiffusionParallelConfig.cfg_parallel_size.", ) return serve_parser