diff --git a/docs/user_guide/diffusion_features.md b/docs/user_guide/diffusion_features.md
index c09705ae05..7e08851812 100644
--- a/docs/user_guide/diffusion_features.md
+++ b/docs/user_guide/diffusion_features.md
@@ -110,7 +110,7 @@ The following tables show which models support each feature:
 | **FLUX.1-dev** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
 | **FLUX.2-klein** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
 | **FLUX.1-Kontext-dev** | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
-| **FLUX.2-dev** | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
+| **FLUX.2-dev** | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
 | **GLM-Image** | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
 | **HunyuanImage3** | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
 | **LongCat-Image** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
diff --git a/examples/offline_inference/text_to_image/README.md b/examples/offline_inference/text_to_image/README.md
index 4796a17692..cc295e8279 100644
--- a/examples/offline_inference/text_to_image/README.md
+++ b/examples/offline_inference/text_to_image/README.md
@@ -247,7 +247,7 @@ python examples/offline_inference/text_to_image/text_to_image.py \
 #### CFG Parallel
 
 Set `--cfg-parallel-size 2` to enable CFG Parallel for faster inference on multi-GPU setups.
-See more examples in the [diffusion acceleration user guide](../../../docs/user_guide/diffusion_acceleration.md#using-cfg-parallel).
+See more examples in the [CFG Parallel user guide](../../../docs/user_guide/parallelism/cfg_parallel.md#using-cfg-parallel).
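+
+For example, a minimal sketch (only `--cfg-parallel-size` is shown; reuse the model and
+prompt flags from the full command above):
+
+```bash
+python examples/offline_inference/text_to_image/text_to_image.py \
+    --cfg-parallel-size 2  # plus the model/prompt flags shown above
+```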
 
 #### LoRA
diff --git a/tests/e2e/online_serving/test_flux_2_dev_expansion.py b/tests/e2e/online_serving/test_flux_2_dev_expansion.py
index eba0fbda22..c7140769ba 100644
--- a/tests/e2e/online_serving/test_flux_2_dev_expansion.py
+++ b/tests/e2e/online_serving/test_flux_2_dev_expansion.py
@@ -29,6 +29,7 @@ NEGATIVE_PROMPT = "low quality, blurry, distorted, deformed, watermark"
 
 SINGLE_CARD_FEATURE_MARKS = hardware_marks(res={"cuda": "H100"})
+PARALLEL_FEATURE_MARKS = hardware_marks(res={"cuda": "L4"}, num_cards=2)
 
 
 def _get_flux_2_dev_feature_cases(model: str):
@@ -47,6 +48,20 @@ def _get_flux_2_dev_feature_cases(model: str):
             id="cache_dit_cpu_offload",
             marks=SINGLE_CARD_FEATURE_MARKS,
         ),
+        pytest.param(
+            OmniServerParams(
+                model=model,
+                server_args=[
+                    "--cache-backend",
+                    "cache_dit",
+                    "--enable-cpu-offload",
+                    "--cfg-parallel-size",
+                    "2",
+                ],
+            ),
+            id="parallel_cfg_2",
+            marks=PARALLEL_FEATURE_MARKS,
+        ),
     ]
diff --git a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py
index 00d3288501..404f05b606 100644
--- a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py
+++ b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py
@@ -25,6 +25,8 @@ from vllm.model_executor.models.utils import AutoWeightsLoader
 
 from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
+from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
+from vllm_omni.diffusion.distributed.parallel_state import get_classifier_free_guidance_world_size
 from vllm_omni.diffusion.distributed.utils import get_local_device
 from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
 from vllm_omni.diffusion.models.flux2 import Flux2Transformer2DModel
@@ -333,7 +335,7 @@ def retrieve_latents(encoder_output: torch.Tensor, generator: torch.Generator =
     raise AttributeError("Could not access latents of provided encoder_output")
 
 
-class Flux2Pipeline(nn.Module, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin):
+class Flux2Pipeline(nn.Module, CFGParallelMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin):
     """Flux2 pipeline for text-to-image generation."""
 
     _callback_tensor_inputs = ["latents", "prompt_embeds"]
@@ -854,6 +856,21 @@ def current_timestep(self):
     def interrupt(self):
         return self._interrupt
 
+    def check_cfg_parallel_validity(self, true_cfg_scale: float, has_neg_prompt: bool):
+        """Warn when CFG parallel is enabled but true CFG cannot run; returns False in that case."""
+        if get_classifier_free_guidance_world_size() == 1:
+            return True
+
+        if true_cfg_scale <= 1:
+            logger.warning("CFG parallel will not work correctly when true_cfg_scale <= 1.")
+            return False
+
+        if not has_neg_prompt:
+            logger.warning(
+                "CFG parallel will not work correctly without a negative prompt or negative prompt embeddings."
+            )
+            return False
+        return True
+
     def forward(
         self,
         req: OmniDiffusionRequest,
@@ -921,6 +938,14 @@
         # And `torch.stack` automatically raises an exception for us
         prompt_embeds = torch.stack(req_prompt_embeds)  # type: ignore # intentionally expect TypeError
 
+        # Default to None so the `has_neg_prompt` check below is well-defined
+        # even when not every request carries precomputed negative embeddings.
+        negative_prompt_embeds = None
+        req_negative_prompt_embeds = [
+            p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts
+        ]
+        if all(p is not None for p in req_negative_prompt_embeds):
+            negative_prompt_embeds = torch.stack(req_negative_prompt_embeds)  # type: ignore # intentionally expect TypeError
+
+        req_negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts]
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt=prompt,
@@ -958,6 +983,22 @@
             text_encoder_out_layers=text_encoder_out_layers,
         )
 
+        has_neg_prompt = negative_prompt_embeds is not None or any(req_negative_prompt)
+        do_true_cfg = self.guidance_scale > 1 and has_neg_prompt
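+        # Note: "true" CFG means a second transformer pass conditioned on the
+        # negative prompt whose prediction is blended with the positive one;
+        # it is separate from the distilled guidance signal passed in via
+        # `guidance_tensor` below. When --cfg-parallel-size > 1, the
+        # CFGParallelMixin helpers can run the two passes on separate ranks.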
+
+        self.check_cfg_parallel_validity(self.guidance_scale, has_neg_prompt)
+        negative_text_ids = None
+        if do_true_cfg:
+            negative_prompt = req_negative_prompt
+            negative_prompt_embeds, negative_text_ids = self.encode_prompt(
+                prompt=negative_prompt,
+                prompt_embeds=negative_prompt_embeds,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                text_encoder_out_layers=text_encoder_out_layers,
+            )
+
         # 4. process images
         if image is not None and not isinstance(image, list):
             image = [image]
@@ -1029,6 +1070,9 @@
         guidance_tensor = torch.full([1], self.guidance_scale, device=device, dtype=torch.float32)
         guidance_tensor = guidance_tensor.expand(latents.shape[0])
 
+        # For editing pipelines, we need to slice the output to remove condition latents
+        output_slice = latents.size(1) if image_latents is not None else None
+
         # 7. Denoising loop
         # We set the index here to remove DtoH sync, helpful especially during compilation.
         # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
@@ -1048,21 +1092,41 @@
             latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
             latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
 
-            noise_pred = self.transformer(
-                hidden_states=latent_model_input,  # (B, image_seq_len, C)
-                timestep=timestep / 1000,
-                guidance=guidance_tensor,
-                encoder_hidden_states=prompt_embeds,
-                txt_ids=text_ids,  # B, text_seq_len, 4
-                img_ids=latent_image_ids,  # B, image_seq_len, 4
-                joint_attention_kwargs=self.attention_kwargs,
-                return_dict=False,
-            )[0]
-
-            noise_pred = noise_pred[:, : latents.size(1) :]
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+            positive_kwargs = {
+                "hidden_states": latent_model_input,  # (B, image_seq_len, C)
+                "timestep": timestep / 1000,
+                "guidance": guidance_tensor,
+                "encoder_hidden_states": prompt_embeds,
+                "txt_ids": text_ids,  # (B, text_seq_len, 4)
+                "img_ids": latent_image_ids,  # (B, image_seq_len, 4)
+                "joint_attention_kwargs": self.attention_kwargs,
+                "return_dict": False,
+            }
+            if do_true_cfg:
+                # The negative pass differs from the positive one only in its
+                # text conditioning, so reuse the positive kwargs.
+                negative_kwargs = {
+                    **positive_kwargs,
+                    "encoder_hidden_states": negative_prompt_embeds,
+                    "txt_ids": negative_text_ids,
+                }
+            else:
+                negative_kwargs = None
+
+            noise_pred = self.predict_noise_maybe_with_cfg(
+                do_true_cfg=do_true_cfg,
+                true_cfg_scale=self.guidance_scale,
+                positive_kwargs=positive_kwargs,
+                negative_kwargs=negative_kwargs,
+                cfg_normalize=False,
+                output_slice=output_slice,
+            )
+
+            # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync
+            latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
 
             if callback_on_step_end is not None:
                 callback_kwargs = {}