Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/user_guide/diffusion_features.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ The following tables show which models support each feature:
| **FLUX.1-dev** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
| **FLUX.2-klein** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
| **FLUX.1-Kontext-dev** | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **FLUX.2-dev** | ❌ | ❌ | ❌ | | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **FLUX.2-dev** | ❌ | ❌ | ❌ | | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **GLM-Image** | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **HunyuanImage3** | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| **LongCat-Image** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
Expand Down
2 changes: 1 addition & 1 deletion examples/offline_inference/text_to_image/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ python examples/offline_inference/text_to_image/text_to_image.py \
#### CFG Parallel

Set `--cfg-parallel-size 2` to enable CFG Parallel for faster inference on multi-GPU setups.
See more examples in the [diffusion acceleration user guide](../../../docs/user_guide/diffusion_acceleration.md#using-cfg-parallel).
See more examples in the [CFG Parallel user guide](../../../docs/user_guide/parallelism/cfg_parallel.md#using-cfg-parallel).

#### LoRA

Expand Down
15 changes: 15 additions & 0 deletions tests/e2e/online_serving/test_flux_2_dev_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
NEGATIVE_PROMPT = "low quality, blurry, distorted, deformed, watermark"

SINGLE_CARD_FEATURE_MARKS = hardware_marks(res={"cuda": "H100"})
PARALLEL_FEATURE_MARKS = hardware_marks(res={"cuda": "L4"}, num_cards=2)


def _get_flux_2_dev_feature_cases(model: str):
Expand All @@ -47,6 +48,20 @@ def _get_flux_2_dev_feature_cases(model: str):
id="cache_dit_cpu_offload",
marks=SINGLE_CARD_FEATURE_MARKS,
),
pytest.param(
OmniServerParams(
model=model,
server_args=[
"--cache-backend",
"cache_dit",
"--enable-cpu-offload",
"--cfg-parallel-size",
"2",
],
),
id="parallel_cfg_2",
marks=PARALLEL_FEATURE_MARKS,
),
]


Expand Down
96 changes: 80 additions & 16 deletions vllm_omni/diffusion/models/flux2/pipeline_flux2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from vllm.model_executor.models.utils import AutoWeightsLoader

from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.parallel_state import get_classifier_free_guidance_world_size
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.flux2 import Flux2Transformer2DModel
Expand Down Expand Up @@ -333,7 +335,7 @@ def retrieve_latents(encoder_output: torch.Tensor, generator: torch.Generator =
raise AttributeError("Could not access latents of provided encoder_output")


class Flux2Pipeline(nn.Module, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin):
class Flux2Pipeline(nn.Module, CFGParallelMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin):
"""Flux2 pipeline for text-to-image generation."""

_callback_tensor_inputs = ["latents", "prompt_embeds"]
Expand Down Expand Up @@ -854,6 +856,21 @@ def current_timestep(self):
def interrupt(self):
return self._interrupt

def check_cfg_parallel_validity(self, true_cfg_scale: float, has_neg_prompt: bool):
    """Check whether classifier-free-guidance (CFG) parallelism can take effect.

    CFG parallel splits the positive and negative guidance branches across
    ranks, so it only does useful work when true CFG is actually enabled.

    Args:
        true_cfg_scale: The true CFG scale requested for this run; CFG is
            effectively disabled when it is <= 1.
        has_neg_prompt: Whether a negative prompt or precomputed negative
            prompt embeddings were supplied.

    Returns:
        True when the configuration is valid (or CFG parallel is not enabled
        at all); False otherwise, after emitting a warning explaining why.
    """
    cfg_world_size = get_classifier_free_guidance_world_size()
    if cfg_world_size == 1:
        # CFG parallel is not enabled; nothing to validate.
        return True

    # With CFG parallel enabled, both guard conditions below mean the
    # negative branch never runs, so the extra ranks would be wasted.
    if true_cfg_scale <= 1:
        logger.warning("CFG parallel is NOT working correctly when true_cfg_scale <= 1.")
        return False

    if has_neg_prompt:
        return True

    logger.warning(
        "CFG parallel is NOT working correctly when there is no negative prompt or negative prompt embeddings."
    )
    return False

def forward(
self,
req: OmniDiffusionRequest,
Expand Down Expand Up @@ -921,6 +938,14 @@ def forward(
# And `torch.stack` automatically raises an exception for us
prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError

req_negative_prompt_embeds = [
p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts
]
if all(p is not None for p in req_negative_prompt_embeds):
negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError

req_negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts]

# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
Expand Down Expand Up @@ -958,6 +983,22 @@ def forward(
text_encoder_out_layers=text_encoder_out_layers,
)

has_neg_prompt = negative_prompt_embeds is not None or any(req_negative_prompt)
do_true_cfg = self.guidance_scale > 1 and has_neg_prompt

self.check_cfg_parallel_validity(self.guidance_scale, has_neg_prompt)
Comment thread
nuclearwu marked this conversation as resolved.
negative_text_ids = None
if do_true_cfg:
negative_prompt = req_negative_prompt
negative_prompt_embeds, negative_text_ids = self.encode_prompt(
Comment thread
nuclearwu marked this conversation as resolved.
prompt=negative_prompt,
prompt_embeds=negative_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
text_encoder_out_layers=text_encoder_out_layers,
)

# 4. process images
if image is not None and not isinstance(image, list):
image = [image]
Expand Down Expand Up @@ -1029,6 +1070,9 @@ def forward(
guidance_tensor = torch.full([1], self.guidance_scale, device=device, dtype=torch.float32)
guidance_tensor = guidance_tensor.expand(latents.shape[0])

# For editing pipelines, we need to slice the output to remove condition latents
output_slice = latents.size(1) if image_latents is not None else None

# 7. Denoising loop
# We set the index here to remove DtoH sync, helpful especially during compilation.
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
Expand All @@ -1048,21 +1092,41 @@ def forward(
latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)

noise_pred = self.transformer(
hidden_states=latent_model_input, # (B, image_seq_len, C)
timestep=timestep / 1000,
guidance=guidance_tensor,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids, # B, text_seq_len, 4
img_ids=latent_image_ids, # B, image_seq_len, 4
joint_attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]

noise_pred = noise_pred[:, : latents.size(1) :]

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
positive_kwargs = {
"hidden_states": latent_model_input,
"timestep": timestep / 1000,
"guidance": guidance_tensor,
"encoder_hidden_states": prompt_embeds,
"txt_ids": text_ids,
"img_ids": latent_image_ids,
"joint_attention_kwargs": self.attention_kwargs,
"return_dict": False,
}
if do_true_cfg:
negative_kwargs = {
"hidden_states": latent_model_input,
"timestep": timestep / 1000,
"guidance": guidance_tensor,
"encoder_hidden_states": negative_prompt_embeds,
"txt_ids": negative_text_ids,
"img_ids": latent_image_ids,
"joint_attention_kwargs": self.attention_kwargs,
"return_dict": False,
}
else:
negative_kwargs = None

noise_pred = self.predict_noise_maybe_with_cfg(
do_true_cfg=do_true_cfg,
true_cfg_scale=self.guidance_scale,
positive_kwargs=positive_kwargs,
negative_kwargs=negative_kwargs,
cfg_normalize=False,
output_slice=output_slice,
)

# Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync
latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)

if callback_on_step_end is not None:
callback_kwargs = {}
Expand Down
Loading