Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/user_guide/diffusion/parallelism_acceleration.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ The following table shows which models are currently supported by parallelism me
| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ✅ (TP=2 only) |
| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ❌ |
| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ❌ | ❌ | ✅ |
| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | | ✅ |
| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | | ✅ |

!!! note "TP Limitations for Diffusion Models"
We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the `text_encoder` component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP.
Expand Down
2 changes: 1 addition & 1 deletion docs/user_guide/diffusion_acceleration.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ The following table shows which models are currently supported by each accelerat
| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ❌ | ❌ |
| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ✅ |
| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ |
| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ✅ | ❌ | ❌ | |
| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ✅ | ❌ | ❌ | |

### VideoGen

Expand Down
97 changes: 55 additions & 42 deletions vllm_omni/diffusion/models/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from vllm.model_executor.models.utils import AutoWeightsLoader

from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.parallel_state import get_classifier_free_guidance_world_size
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
Expand Down Expand Up @@ -129,9 +130,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


class FluxPipeline(
nn.Module,
):
class FluxPipeline(nn.Module, CFGParallelMixin):
def __init__(
self,
*,
Expand Down Expand Up @@ -500,21 +499,23 @@ def interrupt(self):

def diffuse(
self,
prompt_embeds,
pooled_prompt_embeds,
negative_prompt_embeds,
negative_pooled_prompt_embeds,
latents,
latent_image_ids,
text_ids,
negative_text_ids,
timesteps,
do_true_cfg,
guidance,
true_cfg_scale,
):
prompt_embeds: torch.Tensor,
pooled_prompt_embeds: torch.Tensor,
negative_prompt_embeds: torch.Tensor,
negative_pooled_prompt_embeds: torch.Tensor,
latents: torch.Tensor,
latent_image_ids: torch.Tensor,
text_ids: torch.Tensor,
negative_text_ids: torch.Tensor,
timesteps: torch.Tensor,
do_true_cfg: bool,
guidance: torch.Tensor,
true_cfg_scale: float,
cfg_normalize: bool = False,
) -> torch.Tensor:
"""Diffusion loop with optional image conditioning."""
self.scheduler.set_begin_index(0)
self.transformer.do_true_cfg = do_true_cfg
for i, t in enumerate(timesteps):
if self.interrupt:
continue
Expand All @@ -523,36 +524,46 @@ def diffuse(
# broadcast to batch dimension and place on same device/dtype as latents
timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype)

self.transformer.do_true_cfg = do_true_cfg # used in teacache hook
# Forward pass for positive prompt (or unconditional if no CFG)
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
positive_kwargs = {
"hidden_states": latents,
"timestep": timestep / 1000,
"guidance": guidance,
"pooled_projections": pooled_prompt_embeds,
"encoder_hidden_states": prompt_embeds,
"txt_ids": text_ids,
"img_ids": latent_image_ids,
"joint_attention_kwargs": self.joint_attention_kwargs,
"return_dict": False,
}

# Forward pass for negative prompt (CFG)
if do_true_cfg:
neg_noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=negative_text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
negative_kwargs = {
"hidden_states": latents,
"timestep": timestep / 1000,
"guidance": guidance,
"pooled_projections": negative_pooled_prompt_embeds,
"encoder_hidden_states": negative_prompt_embeds,
"txt_ids": negative_text_ids,
"img_ids": latent_image_ids,
"joint_attention_kwargs": self.joint_attention_kwargs,
"return_dict": False,
}
else:
negative_kwargs = None

# Predict noise with automatic CFG parallel handling
noise_pred = self.predict_noise_maybe_with_cfg(
do_true_cfg,
true_cfg_scale,
positive_kwargs,
negative_kwargs,
cfg_normalize,
)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)

return latents

def check_cfg_parallel_validity(self, true_cfg_scale: float, has_neg_prompt: bool):
Expand Down Expand Up @@ -613,6 +624,7 @@ def forward(
req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale
)
generator = req.sampling_params.generator or generator
true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale
num_images_per_prompt = (
req.sampling_params.num_outputs_per_prompt
if req.sampling_params.num_outputs_per_prompt > 0
Expand Down Expand Up @@ -723,6 +735,7 @@ def forward(
do_true_cfg,
guidance,
true_cfg_scale,
cfg_normalize=False,
)

self._current_timestep = None
Expand Down