From 00b5a4ce487aa57fb96cb50364b2582240f42be9 Mon Sep 17 00:00:00 2001 From: wuzhongjian Date: Tue, 3 Mar 2026 15:20:42 +0800 Subject: [PATCH 01/14] [feature]: support Flux.2-dev model Signed-off-by: wuzhongjian --- docs/models/supported_models.md | 1 + .../diffusion/parallelism_acceleration.md | 1 + docs/user_guide/diffusion_acceleration.md | 1 + vllm_omni/diffusion/models/flux2/__init__.py | 17 + .../models/flux2/flux2_transformer.py | 764 ++++++++++++ .../diffusion/models/flux2/pipeline_flux2.py | 1078 +++++++++++++++++ vllm_omni/diffusion/registry.py | 6 + 7 files changed, 1868 insertions(+) create mode 100644 vllm_omni/diffusion/models/flux2/__init__.py create mode 100644 vllm_omni/diffusion/models/flux2/flux2_transformer.py create mode 100644 vllm_omni/diffusion/models/flux2/pipeline_flux2.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 980488852f9..27adbfeb83e 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -44,6 +44,7 @@ th { |`GlmImageForConditionalGeneration` | GLM-Image | `zai-org/GLM-Image` | |`NextStep11Pipeline` | NextStep-1.1 | `stepfun-ai/NextStep-1.1` | |`MiMoAudioForConditionalGeneration` | MiMo-Audio-7B-Instruct | `XiaomiMiMo/MiMo-Audio-7B-Instruct` | +|`Flux2Pipeline` | FLUX.2-dev | `black-forest-labs/FLUX.2-dev` | ## List of Supported Models for NPU diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index b0225e2ffb9..2c11c5ecb50 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -35,6 +35,7 @@ The following table shows which models are currently supported by parallelism me | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ✅ | ✅ | | **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ❌ | ❌ | ✅ | ❌ | | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ✅ | ✅ | ❌ | +| **FLUX.2-dev** | `black-forest-labs/FLUX.2-dev` | ❌ | ❌ | ❌ | ✅ | ❌ | !!! note "TP Limitations for Diffusion Models" We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the `text_encoder` component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP. diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 9fd14a2b070..56351d05298 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -68,6 +68,7 @@ The following table shows which models are currently supported by each accelerat | **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | | **NextStep-1.1** | `stepfun-ai/NextStep-1.1` | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | +| **FLUX.2-dev** | `black-forest-labs/FLUX.2-dev` | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ### VideoGen diff --git a/vllm_omni/diffusion/models/flux2/__init__.py b/vllm_omni/diffusion/models/flux2/__init__.py new file mode 100644 index 00000000000..bce893e69d9 --- /dev/null +++ b/vllm_omni/diffusion/models/flux2/__init__.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Flux2 diffusion model components.""" + +from vllm_omni.diffusion.models.flux2.flux2_transformer import ( + Flux2Transformer2DModel, +) +from vllm_omni.diffusion.models.flux2.pipeline_flux2 import ( + Flux2Pipeline, + get_flux2_post_process_func, +) + +__all__ = [ + "Flux2Pipeline", + "Flux2Transformer2DModel", + "get_flux2_post_process_func", +] diff --git a/vllm_omni/diffusion/models/flux2/flux2_transformer.py b/vllm_omni/diffusion/models/flux2/flux2_transformer.py new file mode 100644 index 00000000000..478d6cbaf7a --- /dev/null +++ b/vllm_omni/diffusion/models/flux2/flux2_transformer.py @@ -0,0 +1,764 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from types import SimpleNamespace +from typing import Any + +import torch +from diffusers.models.embeddings import ( + TimestepEmbedding, + Timesteps, + get_1d_rotary_pos_embed, +) +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import AdaLayerNormContinuous +from diffusers.utils import is_torch_npu_available +from torch import nn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.layers.rope import RotaryEmbedding + + +class Flux2SwiGLU(nn.Module): + """SwiGLU activation used by Flux2.""" + + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(2, dim=-1) + return self.gate_fn(x1) * x2 + + +class Flux2FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: int | None = None, + mult: float = 3.0, + inner_dim: int | None = None, + bias: bool = False, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out or dim + + self.linear_in = MergedColumnParallelLinear( + dim, + [inner_dim, inner_dim], + bias=bias, + return_bias=False, + ) + self.act_fn = Flux2SwiGLU() + self.linear_out = RowParallelLinear( + inner_dim, + dim_out, + bias=bias, + input_is_parallel=True, + return_bias=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_in(x) + x = self.act_fn(x) + return self.linear_out(x) + + +class Flux2Attention(nn.Module): + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: int | None = None, + added_proj_bias: bool | None = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + ): + super().__init__() + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + self.dropout = dropout + self.added_kv_proj_dim = added_kv_proj_dim + + self.to_qkv = QKVParallelLinear( + hidden_size=query_dim, + head_size=self.head_dim, + total_num_heads=self.heads, + bias=bias, + ) + self.query_num_heads = self.to_qkv.num_heads + self.kv_num_heads = self.to_qkv.num_kv_heads + + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + + self.to_out = nn.ModuleList( + [ + RowParallelLinear( + self.inner_dim, + self.out_dim, + bias=out_bias, + input_is_parallel=True, + return_bias=False, + ), + nn.Dropout(dropout), + ] + ) + + if added_kv_proj_dim is not None: + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + self.add_kv_proj = QKVParallelLinear( + hidden_size=added_kv_proj_dim, + head_size=self.head_dim, + total_num_heads=self.heads, + bias=added_proj_bias, + ) + self.add_query_num_heads = self.add_kv_proj.num_heads + self.add_kv_num_heads = self.add_kv_proj.num_kv_heads + self.to_add_out = RowParallelLinear( + self.inner_dim, + query_dim, + bias=out_bias, + input_is_parallel=True, + return_bias=False, + ) + + self.rope = RotaryEmbedding(is_neox_style=False) + self.attn = Attention( + num_heads=self.query_num_heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + num_kv_heads=self.kv_num_heads, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + qkv, _ = self.to_qkv(hidden_states) + query, key, value = qkv.chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and self.added_kv_proj_dim is not None: + encoder_qkv, _ = self.add_kv_proj(encoder_hidden_states) + encoder_query, encoder_key, encoder_value = encoder_qkv.chunk(3, dim=-1) + + query = query.unflatten(-1, (self.query_num_heads, -1)) + key = key.unflatten(-1, (self.kv_num_heads, -1)) + value = value.unflatten(-1, (self.kv_num_heads, -1)) + + query = self.norm_q(query) + key = self.norm_k(key) + + if encoder_hidden_states is not None and self.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (self.add_query_num_heads, -1)) + encoder_key = encoder_key.unflatten(-1, (self.add_kv_num_heads, -1)) + encoder_value = encoder_value.unflatten(-1, (self.add_kv_num_heads, -1)) + + encoder_query = self.norm_added_q(encoder_query) + encoder_key = self.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + query = self.rope(query, cos, sin) + key = self.rope(key, cos, sin) + + attn_metadata = None + if attention_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_metadata = AttentionMetadata(attn_mask=attention_mask) + + hidden_states = self.attn(query, key, value, attn_metadata) + hidden_states = hidden_states.flatten(2, 3).to(query.dtype) + + if encoder_hidden_states is not None: + context_len = encoder_hidden_states.shape[1] + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [context_len, hidden_states.shape[1] - context_len], + dim=1, + ) + encoder_hidden_states = self.to_add_out(encoder_hidden_states) + + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + return hidden_states + + +class Flux2ParallelSelfAttention(nn.Module): + """ + Parallel attention block that fuses QKV projections with MLP input projections. + """ + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + mlp_ratio: float = 4.0, + mlp_mult_factor: int = 2, + ): + super().__init__() + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + self.dropout = dropout + + self.mlp_ratio = mlp_ratio + self.mlp_hidden_dim = int(query_dim * self.mlp_ratio) + self.mlp_mult_factor = mlp_mult_factor + + self.to_qkv_mlp_proj = ColumnParallelLinear( + self.query_dim, + self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, + bias=bias, + gather_output=True, + ) + self.mlp_act_fn = Flux2SwiGLU() + + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + + self.to_out = ColumnParallelLinear( + self.inner_dim + self.mlp_hidden_dim, + self.out_dim, + bias=out_bias, + gather_output=True, + ) + self.rope = RotaryEmbedding(is_neox_style=False) + self.attn = Attention( + num_heads=self.heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: + hidden_states, _ = self.to_qkv_mlp_proj(hidden_states) + qkv, mlp_hidden_states = torch.split( + hidden_states, + [3 * self.inner_dim, self.mlp_hidden_dim * self.mlp_mult_factor], + dim=-1, + ) + + query, key, value = qkv.chunk(3, dim=-1) + query = query.unflatten(-1, (self.heads, -1)) + key = key.unflatten(-1, (self.heads, -1)) + value = value.unflatten(-1, (self.heads, -1)) + + query = self.norm_q(query) + key = self.norm_k(key) + + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + query = self.rope(query, cos, sin) + key = self.rope(key, cos, sin) + + attn_metadata = None + if attention_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_metadata = AttentionMetadata(attn_mask=attention_mask) + + attn_output = self.attn(query, key, value, attn_metadata) + attn_output = attn_output.flatten(2, 3).to(query.dtype) + + mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states) + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=-1) + hidden_states, _ = self.to_out(hidden_states) + return hidden_states + + +class Flux2SingleTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.attn = Flux2ParallelSelfAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + out_bias=bias, + eps=eps, + mlp_ratio=mlp_ratio, + mlp_mult_factor=2, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None, + temb_mod_params: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + split_hidden_states: bool = False, + text_seq_len: int | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if encoder_hidden_states is not None: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + mod_shift, mod_scale, mod_gate = temb_mod_params + + norm_hidden_states = self.norm(hidden_states) + norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift + + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = hidden_states + mod_gate * attn_output + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + if split_hidden_states: + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + return hidden_states + + +class Flux2TransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + self.attn = Flux2Attention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + added_proj_bias=bias, + out_bias=bias, + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb_mod_params_img: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + temb_mod_params_txt: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + joint_attention_kwargs = joint_attention_kwargs or {} + + (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img + (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt + + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa + + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) + norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa + + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + attn_output = gate_msa * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + gate_mlp * ff_output + + context_attn_output = c_gate_msa * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class Flux2PosEmbed(nn.Module): + def __init__(self, theta: int, axes_dim: list[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # Expected ids shape: [S, len(self.axes_dim)] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1] + for i in range(len(self.axes_dim)): + freqs_cis = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[..., i], + theta=self.theta, + use_real=False, + freqs_dtype=freqs_dtype, + ) + cos_out.append(freqs_cis.real) + sin_out.append(freqs_cis.imag) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class Flux2TimestepGuidanceEmbeddings(nn.Module): + def __init__( + self, + in_channels: int = 256, + embedding_dim: int = 6144, + bias: bool = False, + guidance_embeds: bool = True, + ): + super().__init__() + self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=embedding_dim, + sample_proj_bias=bias, + ) + + if guidance_embeds: + self.guidance_embedder = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=embedding_dim, + sample_proj_bias=bias, + ) + else: + self.guidance_embedder = None + + def forward(self, timestep: torch.Tensor, guidance: torch.Tensor | None) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) + + if guidance is not None and self.guidance_embedder is not None: + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) + return timesteps_emb + guidance_emb + return timesteps_emb + + +class Flux2Modulation(nn.Module): + def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False): + super().__init__() + self.mod_param_sets = mod_param_sets + + self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias) + self.act_fn = nn.SiLU() + + def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]: + mod = self.act_fn(temb) + mod = self.linear(mod) + + if mod.ndim == 2: + mod = mod.unsqueeze(1) + mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1) + # Return tuple of 3-tuples of modulation params shift/scale/gate + return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets)) + + +class Flux2Transformer2DModel(nn.Module): + """ + The Transformer model introduced in Flux 2. + """ + + _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] + packed_modules_mapping = { + "to_qkv": ["to_q", "to_k", "to_v"], + "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], + } + + def __init__( + self, + patch_size: int = 1, + in_channels: int = 128, + out_channels: int | None = None, + num_layers: int = 8, + num_single_layers: int = 48, + attention_head_dim: int = 128, + num_attention_heads: int = 48, + joint_attention_dim: int = 15360, + timestep_guidance_channels: int = 256, + mlp_ratio: float = 3.0, + axes_dims_rope: tuple[int, ...] = (32, 32, 32, 32), + rope_theta: int = 2000, + eps: float = 1e-6, + guidance_embeds: bool = True, + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.config = SimpleNamespace( + patch_size=patch_size, + in_channels=in_channels, + out_channels=self.out_channels, + num_layers=num_layers, + num_single_layers=num_single_layers, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + joint_attention_dim=joint_attention_dim, + timestep_guidance_channels=timestep_guidance_channels, + mlp_ratio=mlp_ratio, + axes_dims_rope=axes_dims_rope, + rope_theta=rope_theta, + eps=eps, + guidance_embeds=guidance_embeds, + ) + + self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope) + self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( + in_channels=timestep_guidance_channels, + embedding_dim=self.inner_dim, + bias=False, + guidance_embeds=guidance_embeds, + ) + + self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False) + + self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False) + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False) + + self.transformer_blocks = nn.ModuleList( + [ + Flux2TransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + Flux2SingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False + ) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + return_dict: bool = True, + ) -> torch.Tensor | Transformer2DModelOutput: + joint_attention_kwargs = joint_attention_kwargs or {} + + num_txt_tokens = encoder_hidden_states.shape[1] + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = self.time_guidance_embed(timestep, guidance) + + double_stream_mod_img = self.double_stream_modulation_img(temb) + double_stream_mod_txt = self.double_stream_modulation_txt(temb) + single_stream_mod = self.single_stream_modulation(temb)[0] + + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if img_ids.ndim == 3: + img_ids = img_ids[0] + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + + if is_torch_npu_available(): + freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu()) + image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu()) + freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu()) + text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu()) + else: + image_rotary_emb = self.pos_embed(img_ids) + text_rotary_emb = self.pos_embed(txt_ids) + concat_rotary_emb = ( + torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), + torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), + ) + + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_mod_params_img=double_stream_mod_img, + temb_mod_params_txt=double_stream_mod_txt, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for index_block, block in enumerate(self.single_transformer_blocks): + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=None, + temb_mod_params=single_stream_mod, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + hidden_states = hidden_states[:, num_txt_tokens:, ...] + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + (".to_qkv", ".to_q", "q"), + (".to_qkv", ".to_k", "k"), + (".to_qkv", ".to_v", "v"), + (".add_kv_proj", ".add_q_proj", "q"), + (".add_kv_proj", ".add_k_proj", "k"), + (".add_kv_proj", ".add_v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + + for name, buffer in self.named_buffers(): + if name.endswith(".beta") or name.endswith(".eps"): + params_dict[name] = buffer + + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "to_qkvkv_mlp_proj" in name: + name = name.replace("to_qkvkv_mlp_proj", "to_qkv_mlp_proj") + if "to_qkv_mlp_proj" in name: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py new file mode 100644 index 00000000000..e057c985caa --- /dev/null +++ b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py @@ -0,0 +1,1078 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import inspect +import json +import logging +import math +import os +from collections.abc import Callable, Iterable +from typing import Any, cast + +import numpy as np +import PIL.Image +import torch +from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.flux2.pipeline_flux2 import UPSAMPLING_MAX_IMAGE_SIZE +from diffusers.pipelines.flux2.system_messages import ( + SYSTEM_MESSAGE, + SYSTEM_MESSAGE_UPSAMPLING_I2I, + SYSTEM_MESSAGE_UPSAMPLING_T2I, +) +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import AutoProcessor, Mistral3ForConditionalGeneration, PixtralProcessor +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.flux2 import Flux2Transformer2DModel +from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific + +logger = logging.getLogger(__name__) + + +class Flux2ImageProcessor(VaeImageProcessor): + """Image processor to preprocess the reference image for Flux2.""" + + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 16, + vae_latent_channels: int = 32, + do_normalize: bool = True, + do_convert_rgb: bool = True, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + vae_latent_channels=vae_latent_channels, + do_normalize=do_normalize, + do_convert_rgb=do_convert_rgb, + ) + + @staticmethod + def check_image_input( + image: PIL.Image.Image, + max_aspect_ratio: int = 8, + min_side_length: int = 64, + max_area: int = 1024 * 1024, + ) -> PIL.Image.Image: + if not isinstance(image, PIL.Image.Image): + raise ValueError(f"Image must be a PIL.Image.Image, got {type(image)}") + + width, height = image.size + if width < min_side_length or height < min_side_length: + raise ValueError(f"Image too small: {width}x{height}. Both dimensions must be at least {min_side_length}px") + + aspect_ratio = max(width / height, height / width) + if aspect_ratio > max_aspect_ratio: + raise ValueError( + f"Aspect ratio too extreme: {width}x{height} (ratio: {aspect_ratio:.1f}:1). " + f"Maximum allowed ratio is {max_aspect_ratio}:1" + ) + + if width * height > max_area: + logger.warning("Image area exceeds recommended maximum; resizing will be applied.") + + return image + + @staticmethod + def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image: + image_width, image_height = image.size + scale = math.sqrt(target_area / (image_width * image_height)) + width = int(image_width * scale) + height = int(image_height * scale) + return image.resize((width, height), PIL.Image.Resampling.LANCZOS) + + @staticmethod + def _resize_if_exceeds_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image: + image_width, image_height = image.size + if image_width * image_height <= target_area: + return image + return Flux2ImageProcessor._resize_to_target_area(image, target_area) + + def _resize_and_crop(self, image: PIL.Image.Image, width: int, height: int) -> PIL.Image.Image: + image_width, image_height = image.size + left = (image_width - width) // 2 + top = (image_height - height) // 2 + right = left + width + bottom = top + height + return image.crop((left, top, right, bottom)) + + @staticmethod + def concatenate_images(images: list[PIL.Image.Image]) -> PIL.Image.Image: + if len(images) == 1: + return images[0].copy() + + images = [img.convert("RGB") if img.mode != "RGB" else img for img in images] + total_width = sum(img.width for img in images) + max_height = max(img.height for img in images) + background_color = (255, 255, 255) + new_img = PIL.Image.new("RGB", (total_width, max_height), background_color) + + x_offset = 0 + for img in images: + y_offset = (max_height - img.height) // 2 + new_img.paste(img, (x_offset, y_offset)) + x_offset += img.width + + return new_img + + +def get_flux2_post_process_func( + od_config: OmniDiffusionConfig, +): + if od_config.output_type == "latent": + return lambda x: x + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + vae_scale_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) if "block_out_channels" in vae_config else 8 + + image_processor = Flux2ImageProcessor(vae_scale_factor=vae_scale_factor * 2) + + def post_process_func(images: torch.Tensor): + return image_processor.postprocess(images) + + return post_process_func + + +# Adapted from +# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68 +def format_input( + prompts: list[str], + system_message: str = SYSTEM_MESSAGE, + images: list[PIL.Image.Image] | list[list[PIL.Image.Image]] = None, +): + """ + Format a batch of text prompts into the conversation format expected by apply_chat_template. Optionally, add images + to the input. + + Args: + prompts: List of text prompts + system_message: System message to use (default: CREATIVE_SYSTEM_MESSAGE) + images (optional): List of images to add to the input. + + Returns: + List of conversations, where each conversation is a list of message dicts + """ + # Remove [IMG] tokens from prompts to avoid Pixtral validation issues + # when truncation is enabled. The processor counts [IMG] tokens and fails + # if the count changes after truncation. + cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + + if images is None or len(images) == 0: + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + else: + assert len(images) == len(prompts), "Number of images must match number of prompts" + messages = [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + ] + for _ in cleaned_txt + ] + + for i, (el, images) in enumerate(zip(messages, images)): + # optionally add the images per batch element. + if images is not None: + el.append( + { + "role": "user", + "content": [{"type": "image", "image": image_obj} for image_obj in images], + } + ) + # add the text. + el.append( + { + "role": "user", + "content": [{"type": "text", "text": cleaned_txt[i]}], + } + ) + + return messages + + +# Adapted from +# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L49C5-L66C19 +def _validate_and_process_images( + images: list[list[PIL.Image.Image]] | list[PIL.Image.Image], + image_processor: Flux2ImageProcessor, + upsampling_max_image_size: int, +) -> list[list[PIL.Image.Image]]: + # Simple validation: ensure it's a list of PIL images or list of lists of PIL images + if not images: + return [] + + # Check if it's a list of lists or a list of images + if isinstance(images[0], PIL.Image.Image): + # It's a list of images, convert to list of lists + images = [[im] for im in images] + + # potentially concatenate multiple images to reduce the size + images = [[image_processor.concatenate_images(img_i)] if len(img_i) > 1 else img_i for img_i in images] + + # cap the pixels + images = [ + [image_processor._resize_if_exceeds_area(img_i, upsampling_max_image_size) for img_i in img_i] + for img_i in images + ] + return images + + +# Taken from +# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L251 +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents(encoder_output: torch.Tensor, generator: torch.Generator = None, sample_mode: str = "sample"): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Flux2Pipeline(nn.Module, CFGParallelMixin, SupportImageInput): + """Flux2 pipeline for text-to-image generation.""" + + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + support_image_input = True + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ) + ] + + self._execution_device = get_local_device() + model = od_config.model + # Check if model is a local path + local_files_only = os.path.exists(model) + + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model, subfolder="scheduler", local_files_only=local_files_only + ) + self.text_encoder = Mistral3ForConditionalGeneration.from_pretrained( + model, subfolder="text_encoder", local_files_only=local_files_only + ) + self.tokenizer = PixtralProcessor.from_pretrained( + model, subfolder="tokenizer", local_files_only=local_files_only + ) + self.vae = AutoencoderKLFlux2.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( + self._execution_device + ) + transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, Flux2Transformer2DModel) + self.transformer = Flux2Transformer2DModel(**transformer_kwargs) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 512 + self.default_sample_size = 128 + + self.system_message = SYSTEM_MESSAGE + self.system_message_upsampling_t2i = SYSTEM_MESSAGE_UPSAMPLING_T2I + self.system_message_upsampling_i2i = SYSTEM_MESSAGE_UPSAMPLING_I2I + self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE + + self._guidance_scale = None + self._attention_kwargs = None + self._num_timesteps = None + self._current_timestep = None + self._interrupt = False + + @staticmethod + def _get_mistral_3_small_prompt_embeds( + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + prompt: str | list[str], + dtype: torch.dtype | None = None, + device: torch.device | None = None, + max_sequence_length: int = 512, + system_message: str = SYSTEM_MESSAGE, + hidden_states_layers: list[int] = (10, 20, 30), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + # Format input messages + messages_batch = format_input(prompts=prompt, system_message=system_message) + + # Process all messages at once + inputs = tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + # Move to device + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @staticmethod + def _prepare_text_ids( + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: torch.Tensor | None = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + seq_positions = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, seq_positions) + out_ids.append(coords) + + return torch.stack(out_ids) + + @staticmethod + def _prepare_latent_ids( + latents: torch.Tensor, # (B, C, H, W) + ): + r""" + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents (torch.Tensor): + Latent tensor of shape (B, C, H, W) + + Returns: + torch.Tensor: + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, + H=[0..H-1], W=[0..W-1], L=0 + """ + + batch_size, _, height, width = latents.shape + + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + layer_ids = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, layer_ids) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + def _prepare_image_ids( + image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + scale: int = 10, + ): + r""" + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + This function creates a unique coordinate for every pixel/patch across all input latent with different + dimensions. + + Args: + image_latents (List[torch.Tensor]): + A list of image latent feature tensors, typically of shape (C, H, W). + scale (int, optional): + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th + latent is: 'scale + scale * i'. Defaults to 10. + + Returns: + torch.Tensor: + The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all + input latents. + + Coordinate Components (Dimension 4): + - T (Time): The unique index indicating which latent image the coordinate belongs to. + - H (Height): The row index within that latent image. + - W (Width): The column index within that latent image. + - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1) + """ + + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + # create time offset for each reference image + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + def _patchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + def _unpatchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @staticmethod + def _pack_latents(latents): + """ + pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) + """ + + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + return latents + + @staticmethod + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + def upsample_prompt( + self, + prompt: str | list[str], + images: list[PIL.Image.Image] | list[list[PIL.Image.Image]] = None, + temperature: float = 0.15, + device: torch.device = None, + ) -> list[str]: + prompt = [prompt] if isinstance(prompt, str) else prompt + device = self.text_encoder.device if device is None else device + + # Set system message based on whether images are provided + if images is None or len(images) == 0 or images[0] is None: + system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I + else: + system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I + + # Validate and process the input images + if images: + images = _validate_and_process_images(images, self.image_processor, self.upsampling_max_image_size) + + # Format input messages + messages_batch = format_input(prompts=prompt, system_message=system_message, images=images) + + # Process all messages at once + # with image processing a too short max length can throw an error in here. + inputs = self.tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=2048, + ) + + # Move to device + inputs["input_ids"] = inputs["input_ids"].to(device) + inputs["attention_mask"] = inputs["attention_mask"].to(device) + + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(device, self.text_encoder.dtype) + + # Generate text using the model's generate method + generated_ids = self.text_encoder.generate( + **inputs, + max_new_tokens=512, + do_sample=True, + temperature=temperature, + use_cache=True, + ) + + # Decode only the newly generated tokens (skip input tokens) + # Extract only the generated portion + input_length = inputs["input_ids"].shape[1] + generated_tokens = generated_ids[:, input_length:] + + upsampled_prompt = self.tokenizer.tokenizer.batch_decode( + generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + return upsampled_prompt + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + text_encoder_out_layers: tuple[int, ...] = (10, 20, 30), + ): + device = device or self._execution_device + + if prompt is None: + prompt = "" + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self._get_mistral_3_small_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + system_message=self.system_message, + hidden_states_layers=text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self._prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + def prepare_latents( + self, + batch_size, + num_latents_channels, + height, + width, + dtype, + device, + generator: torch.Generator, + latents: torch.Tensor | None = None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_latents_channels * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(device) + + latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] + return latents, latent_ids + + def prepare_image_latents( + self, + images: list[torch.Tensor], + batch_size, + generator: torch.Generator, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + imagge_latent = self._encode_vae_image(image=image, generator=generator) + image_latents.append(imagge_latent) # (1, 128, 32, 32) + + image_latent_ids = self._prepare_image_ids(image_latents) + + # Pack each latent and concatenate + packed_latents = [] + for latent in image_latents: + # latent: (1, 128, 32, 32) + packed = self._pack_latents(latent) # (1, 1024, 128) + packed = packed.squeeze(0) # (1024, 128) - remove batch dim + packed_latents.append(packed) + + # Concatenate all reference tokens along sequence dimension + image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128) + image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + return image_latents, image_latent_ids + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if ( + height is not None + and height % (self.vae_scale_factor * 2) != 0 + or width is not None + and width % (self.vae_scale_factor * 2) != 0 + ): + logger.warning( + "`height` and `width` have to be divisible by %s but are %s and %s. " + "Dimensions will be resized accordingly", + self.vae_scale_factor * 2, + height, + width, + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, " + f"but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def forward( + self, + req: OmniDiffusionRequest, + image: PIL.Image.Image | list[PIL.Image.Image] | None = None, + prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float | None = 4.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + text_encoder_out_layers: tuple[int, ...] = (10, 20, 30), + caption_upsample_temperature: float = None, + ) -> DiffusionOutput: + if len(req.prompts) > 1: + logger.warning( + """This model only supports a single prompt, not a batched request.""", + """Taking only the first image for now.""", + ) + first_prompt = req.prompts[0] + prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") + + if ( + raw_image := None + if isinstance(first_prompt, str) + else first_prompt.get("multi_modal_data", {}).get("image") + ) is None: + pass # use image from param list + elif isinstance(raw_image, list): + image = [PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im) for im in raw_image] + else: + image = PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image) + + height = req.sampling_params.height or height + width = req.sampling_params.width or width + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas or sigmas + guidance_scale = ( + req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale + ) + generator = req.sampling_params.generator or generator + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) + max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + text_encoder_out_layers = req.sampling_params.extra_args.get("text_encoder_out_layers", text_encoder_out_layers) + + req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + # If at list one prompt is provided as an embedding, + # Then assume that the user wants to provide embeddings for all prompts, and enter this if block + # If the user in fact provides mixed input format, req_prompt_embeds will have some None's + # And `torch.stack` automatically raises an exception for us + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + + req_negative_prompt_embeds = [ + p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts + ] + if any(p is not None for p in req_negative_prompt_embeds): + torch.stack(req_negative_prompt_embeds) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. prepare text embeddings + if caption_upsample_temperature: + prompt = self.upsample_prompt(prompt, images=image, temperature=caption_upsample_temperature, device=device) + prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + # 4. process images + if image is not None and not isinstance(image, list): + image = [image] + + condition_images = None + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + + condition_images = [] + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + height = height or image_height + width = width or image_width + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 5. prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_ids = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_latents_channels=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + image_latents = None + image_latent_ids = None + if condition_images is not None: + image_latents, image_latent_ids = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + generator=generator, + device=device, + dtype=self.vae.dtype, + ) + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + + # 7. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents.to(self.transformer.dtype) + latent_image_ids = latent_ids + + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) + latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) + + noise_pred = self.transformer( + hidden_states=latent_model_input, # (B, image_seq_len, C) + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=latent_image_ids, # B, image_seq_len, 4 + joint_attention_kwargs=self._attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1) :] + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + self._current_timestep = None + + latents = self._unpack_latents_with_ids(latents, latent_ids) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + latents = self._unpatchify_latents(latents) + if output_type == "latent": + image = latents + else: + if latents.dtype != self.vae.dtype: + latents = latents.to(self.vae.dtype) + image = self.vae.decode(latents, return_dict=False)[0] + + return DiffusionOutput(output=image) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index ed99b1473b3..b88c237fbd2 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -110,6 +110,11 @@ "pipeline_omnigen2", "OmniGen2Pipeline", ), + "Flux2Pipeline": ( + "flux2", + "pipeline_flux2", + "Flux2Pipeline", + ), } @@ -296,6 +301,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "NextStep11Pipeline": "get_nextstep11_post_process_func", "FluxPipeline": "get_flux_post_process_func", "OmniGen2Pipeline": "get_omnigen2_post_process_func", + "Flux2Pipeline": "get_flux2_post_process_func", } _DIFFUSION_PRE_PROCESS_FUNCS = { From ae1d8d2832bcbbaa29c53e2aad77da7ac3c6799e Mon Sep 17 00:00:00 2001 From: wuzhongjian Date: Tue, 3 Mar 2026 15:31:37 +0800 Subject: [PATCH 02/14] [feature]: support Flux.2-dev model Signed-off-by: wuzhongjian --- vllm_omni/diffusion/models/flux2/pipeline_flux2.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py index e057c985caa..002f041f2fc 100644 --- a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py +++ b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py @@ -1054,6 +1054,15 @@ def forward( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + self._current_timestep = None latents = self._unpack_latents_with_ids(latents, latent_ids) From 7ec051f50e30224a446724f6126ddbd2627371b5 Mon Sep 17 00:00:00 2001 From: wuzhongjian Date: Tue, 3 Mar 2026 16:23:16 +0800 Subject: [PATCH 03/14] [feature]: support Flux.2-dev model Signed-off-by: wuzhongjian --- vllm_omni/diffusion/models/flux2/pipeline_flux2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py index 002f041f2fc..7194ef095b0 100644 --- a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py +++ b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py @@ -155,7 +155,7 @@ def format_input( prompts: list[str], system_message: str = SYSTEM_MESSAGE, images: list[PIL.Image.Image] | list[list[PIL.Image.Image]] = None, -): +) -> list[list[dict[str, Any]]]: """ Format a batch of text prompts into the conversation format expected by apply_chat_template. Optionally, add images to the input. @@ -166,7 +166,7 @@ def format_input( images (optional): List of images to add to the input. Returns: - List of conversations, where each conversation is a list of message dicts + `list[list[dict[str, Any]]]`: List of conversations, where each conversation is a list of message dicts """ # Remove [IMG] tokens from prompts to avoid Pixtral validation issues # when truncation is enabled. The processor counts [IMG] tokens and fails @@ -271,7 +271,7 @@ def retrieve_timesteps( timesteps: list[int] | None = None, sigmas: list[float] | None = None, **kwargs, -): +) -> tuple[torch.Tensor, int]: r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. @@ -292,7 +292,7 @@ def retrieve_timesteps( `num_inference_steps` and `timesteps` must be `None`. Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: From ceeec04d1f8e588c54834a9b95ad68caaa0800a8 Mon Sep 17 00:00:00 2001 From: wuzhongjian Date: Tue, 3 Mar 2026 16:26:55 +0800 Subject: [PATCH 04/14] [feature]: support Flux.2-dev model Signed-off-by: wuzhongjian --- docs/user_guide/diffusion_acceleration.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 2c36c966c0a..bf2d04863c7 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -68,6 +68,7 @@ The following table shows which models are currently supported by each accelerat | **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | | **NextStep-1.1** | `stepfun-ai/NextStep-1.1` | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | +| **FLUX.2-dev** | `black-forest-labs/FLUX.2-dev` | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ### VideoGen From 910c2d3f3a0f3eb521782473aea76980051abda9 Mon Sep 17 00:00:00 2001 From: wuzhongjian Date: Tue, 3 Mar 2026 18:04:03 +0800 Subject: [PATCH 05/14] [feature]: support Flux.2-dev model Signed-off-by: wuzhongjian --- vllm_omni/diffusion/models/flux2/pipeline_flux2.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py index 7194ef095b0..87550154a5e 100644 --- a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py +++ b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py @@ -908,12 +908,6 @@ def forward( # And `torch.stack` automatically raises an exception for us prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError - req_negative_prompt_embeds = [ - p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts - ] - if any(p is not None for p in req_negative_prompt_embeds): - torch.stack(req_negative_prompt_embeds) - # 1. Check inputs. Raise error if not correct self.check_inputs( prompt=prompt, From 496276b05b87b78152d2664fee4112cc53520882 Mon Sep 17 00:00:00 2001 From: wuzhongjian Date: Wed, 4 Mar 2026 14:33:44 +0800 Subject: [PATCH 06/14] [feature]: support Flux.2-dev model Signed-off-by: wuzhongjian --- .../diffusion/models/flux2/pipeline_flux2.py | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py index 87550154a5e..e80e9940b25 100644 --- a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py +++ b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py @@ -25,7 +25,6 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.flux2 import Flux2Transformer2DModel @@ -149,8 +148,7 @@ def post_process_func(images: torch.Tensor): return post_process_func -# Adapted from -# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68 +# Copied from diffusers.pipelines.flux2.pipeline_flux2.format_input def format_input( prompts: list[str], system_message: str = SYSTEM_MESSAGE, @@ -216,8 +214,7 @@ def format_input( return messages -# Adapted from -# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L49C5-L66C19 +# Copied from diffusers.pipelines.flux2.pipeline_flux2._validate_and_process_images def _validate_and_process_images( images: list[list[PIL.Image.Image]] | list[PIL.Image.Image], image_processor: Flux2ImageProcessor, @@ -243,8 +240,7 @@ def _validate_and_process_images( return images -# Taken from -# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L251 +# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: a1, b1 = 8.73809524e-05, 1.89833333 a2, b2 = 0.00016927, 0.45666666 @@ -335,7 +331,7 @@ def retrieve_latents(encoder_output: torch.Tensor, generator: torch.Generator = raise AttributeError("Could not access latents of provided encoder_output") -class Flux2Pipeline(nn.Module, CFGParallelMixin, SupportImageInput): +class Flux2Pipeline(nn.Module, SupportImageInput): """Flux2 pipeline for text-to-image generation.""" _callback_tensor_inputs = ["latents", "prompt_embeds"] @@ -449,6 +445,7 @@ def _get_mistral_3_small_prompt_embeds( return prompt_embeds @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids def _prepare_text_ids( x: torch.Tensor, # (B, L, D) or (L, D) t_coord: torch.Tensor | None = None, @@ -468,6 +465,7 @@ def _prepare_text_ids( return torch.stack(out_ids) @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids def _prepare_latent_ids( latents: torch.Tensor, # (B, C, H, W) ): @@ -500,6 +498,7 @@ def _prepare_latent_ids( return latent_ids @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids def _prepare_image_ids( image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] scale: int = 10, @@ -550,6 +549,7 @@ def _prepare_image_ids( return image_latent_ids @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents def _patchify_latents(latents): batch_size, num_channels_latents, height, width = latents.shape latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) @@ -558,6 +558,7 @@ def _patchify_latents(latents): return latents @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents def _unpatchify_latents(latents): batch_size, num_channels_latents, height, width = latents.shape latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) @@ -566,6 +567,7 @@ def _unpatchify_latents(latents): return latents @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents def _pack_latents(latents): """ pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) @@ -577,6 +579,7 @@ def _pack_latents(latents): return latents @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: """ using position ids to scatter tokens into place @@ -602,6 +605,7 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch return torch.stack(x_list, dim=0) + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.upsample_prompt def upsample_prompt( self, prompt: str | list[str], @@ -699,6 +703,7 @@ def encode_prompt( text_ids = text_ids.to(device) return prompt_embeds, text_ids + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if image.ndim != 4: raise ValueError(f"Expected image dims 4, got {image.ndim}.") @@ -712,6 +717,7 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): return image_latents + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents def prepare_latents( self, batch_size, @@ -745,6 +751,7 @@ def prepare_latents( latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] return latents, latent_ids + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents def prepare_image_latents( self, images: list[torch.Tensor], From 72a3af8a337c18285e14ab99f00aaafb6bd60c90 Mon Sep 17 00:00:00 2001 From: wuzhongjian Date: Wed, 4 Mar 2026 15:57:13 +0800 Subject: [PATCH 07/14] [feature]: support Flux.2-dev model Signed-off-by: wuzhongjian --- .../models/flux2/flux2_transformer.py | 50 +++++++++++-------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/vllm_omni/diffusion/models/flux2/flux2_transformer.py b/vllm_omni/diffusion/models/flux2/flux2_transformer.py index 478d6cbaf7a..040f2779a8c 100644 --- a/vllm_omni/diffusion/models/flux2/flux2_transformer.py +++ b/vllm_omni/diffusion/models/flux2/flux2_transformer.py @@ -552,10 +552,6 @@ class Flux2Transformer2DModel(nn.Module): """ _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] - packed_modules_mapping = { - "to_qkv": ["to_q", "to_k", "to_v"], - "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], - } def __init__( self, @@ -575,6 +571,7 @@ def __init__( guidance_embeds: bool = True, ): super().__init__() + self.stacked_params_mapping = None self.out_channels = out_channels or in_channels self.inner_dim = num_attention_heads * attention_head_dim self.config = SimpleNamespace( @@ -724,13 +721,15 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ - (".to_qkv", ".to_q", "q"), - (".to_qkv", ".to_k", "k"), - (".to_qkv", ".to_v", "v"), + (".to_qkv.", ".to_q.", "q"), + (".to_qkv.", ".to_k.", "k"), + (".to_qkv.", ".to_v.", "v"), (".add_kv_proj", ".add_q_proj", "q"), (".add_kv_proj", ".add_k_proj", "k"), (".add_kv_proj", ".add_v_proj", "v"), ] + # Expose packed shard mappings for LoRA handling of fused projections. + self.stacked_params_mapping = stacked_params_mapping params_dict = dict(self.named_parameters()) @@ -740,25 +739,32 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_params: set[str] = set() for name, loaded_weight in weights: - if "to_qkvkv_mlp_proj" in name: - name = name.replace("to_qkvkv_mlp_proj", "to_qkv_mlp_proj") - if "to_qkv_mlp_proj" in name: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - continue + original_name = name + mapped = False for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: + if weight_name not in original_name: continue - name = name.replace(weight_name, param_name) - param = params_dict[name] + name = original_name.replace(weight_name, param_name) + param = params_dict.get(name) + if param is None: + break weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) + mapped = True break - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) + if mapped: + continue + + name = original_name + if name not in params_dict and ".to_out.0." in name: + name = name.replace(".to_out.0.", ".to_out.") + # Some GGUF checkpoints include quantized tensors for modules that + # are intentionally left unquantized in this model. + param = params_dict.get(name) + if param is None: + continue + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params From df1782d1570190cb012c0677510789f3a9747020 Mon Sep 17 00:00:00 2001 From: wuzhongjian Date: Wed, 4 Mar 2026 17:00:35 +0800 Subject: [PATCH 08/14] [feature]: support Flux.2-dev model Signed-off-by: wuzhongjian --- .../models/flux2/test_flux2_transformer.py | 244 ++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 tests/diffusion/models/flux2/test_flux2_transformer.py diff --git a/tests/diffusion/models/flux2/test_flux2_transformer.py b/tests/diffusion/models/flux2/test_flux2_transformer.py new file mode 100644 index 00000000000..a2d1fe6abd3 --- /dev/null +++ b/tests/diffusion/models/flux2/test_flux2_transformer.py @@ -0,0 +1,244 @@ +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from vllm_omni.diffusion.models.flux2.flux2_transformer import ( + Flux2PosEmbed, + Flux2Transformer2DModel, +) + + +# Initialize TP group before tests +@pytest.fixture(scope="function", autouse=True) +def setup_tp_group(): + """Set up TP group for each test function""" + with patch("vllm.model_executor.layers.linear.get_tensor_model_parallel_world_size", return_value=2): + with patch("vllm.distributed.parallel_state.get_tp_group") as mock_get_tp_group: + mock_tp_group = MagicMock() + mock_tp_group.world_size = 2 + mock_get_tp_group.return_value = mock_tp_group + yield + + +class TestFlux2TransformerWeightLoading: + """Test Flux2Transformer weight loading functionality""" + + def test_weight_loading_tp2(self, setup_tp_group): + """Verify weights load correctly with TP=2""" + # Prepare test data + model = Flux2Transformer2DModel( + num_layers=2, + num_single_layers=1, + num_attention_heads=48, + attention_head_dim=128, + joint_attention_dim=15360, + ) + + # Mock TP=2 weight loading + mock_weights = [] + + # 1. Test regular weight loading + mock_weights.append(("x_embedder.weight", torch.randn(6144, 128))) + mock_weights.append(("context_embedder.weight", torch.randn(6144, 15360))) + mock_weights.append(("proj_out.weight", torch.randn(128, 6144))) + + # 2. Test stacked_params_mapping weight loading + # Full weights - load_weights handles sharding internally + to_qkv_weight = torch.randn(18432, 6144) # Full weights + mock_weights.append(("transformer_blocks.0.attn.to_qkv.weight", to_qkv_weight)) + + # Add_kv_proj weights + add_kv_proj_weight = torch.randn(18432, 6144) # Full weights + mock_weights.append(("transformer_blocks.0.attn.add_kv_proj.weight", add_kv_proj_weight)) + + # 3. Test single block weight loading + single_block_qkv_mlp_weight = torch.randn(18432 + 18432 * 3, 6144) # Full weights + mock_weights.append(("single_transformer_blocks.0.attn.to_qkv_mlp_proj.weight", single_block_qkv_mlp_weight)) + + # Execute weight loading + loaded_params = model.load_weights(mock_weights) + + # Verify + assert len(loaded_params) > 0, "Parameters should be loaded" + + # Verify stacked_params_mapping is correctly set + assert model.stacked_params_mapping is not None + # Should have 6 mappings: 3 for to_qkv + 3 for add_kv_proj + assert len(model.stacked_params_mapping) == 6, "Should have 6 mappings" + + # Verify weights are correctly loaded to corresponding layers + assert hasattr(model.transformer_blocks[0].attn.to_qkv, "weight") + # With TP=2, weight dimension on each GPU should be 18432/2 = 9216 + assert model.transformer_blocks[0].attn.to_qkv.weight.shape == (9216, 6144), ( + f"With TP=2, to_qkv weight dimension should be (9216, 6144), got {model.transformer_blocks[0].attn.to_qkv.weight.shape}" + ) + + +class TestFlux2RopePositionEmbedding: + """Test Flux2 RoPE position embedding functionality""" + + def test_rope_position_embedding(self): + """Verify RoPE produces correct embeddings for 4D coordinates""" + # Prepare test data - use model default configuration + # axes_dims_rope default is (32, 32, 32, 32) + # get_1d_rotary_pos_embed outputs half the dimension for real/imag parts + # So actual output dimension should be (16+16+16+16) = 64 + axes_dims = (32, 32, 32, 32) # Use model default + rope_theta = 2000 # Model default is 2000 + pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims) + + # Create test IDs + seq_len = 10 + ids = torch.randint(0, 100, (seq_len, 4)) # [S, 4] + + # Forward pass + freqs_cos, freqs_sin = pos_embed(ids) + + # Verify output shape - based on model config, expected dimension is 64 + # Each axes_dim=32 outputs 16-dim real part, sum of 4 dimensions = 64 + expected_dim = sum(axes_dims) // 2 # 128/2 = 64 + assert freqs_cos.shape == (seq_len, expected_dim), ( + f"Expected shape {(seq_len, expected_dim)}, got {freqs_cos.shape}" + ) + assert freqs_sin.shape == (seq_len, expected_dim), ( + f"Expected shape {(seq_len, expected_dim)}, got {freqs_sin.shape}" + ) + + # Verify output type - NPU may use float32, other devices use float64 + assert freqs_cos.dtype in [torch.float32, torch.float64], "Should be float type" + assert freqs_sin.dtype in [torch.float32, torch.float64], "Should be float type" + + # Verify value range + assert torch.all(freqs_cos >= -1) and torch.all(freqs_cos <= 1), "cos values should be in [-1, 1]" + assert torch.all(freqs_sin >= -1) and torch.all(freqs_sin <= 1), "sin values should be in [-1, 1]" + + # Verify trigonometric relationship: cos^2 + sin^2 ≈ 1 + cos_sq_sin_sq = freqs_cos**2 + freqs_sin**2 + assert torch.allclose(cos_sq_sin_sq, torch.ones_like(cos_sq_sin_sq), atol=1e-6), "cos^2 + sin^2 should ≈ 1" + + # Verify different positions produce different embeddings + ids_diff = torch.randint(0, 100, (seq_len, 4)) + freqs_cos_diff, freqs_sin_diff = pos_embed(ids_diff) + assert not torch.allclose(freqs_cos, freqs_cos_diff), "Different positions should produce different embeddings" + + # Verify same positions produce same embeddings + ids_same = ids.clone() + freqs_cos_same, freqs_sin_same = pos_embed(ids_same) + assert torch.allclose(freqs_cos, freqs_cos_same), "Same positions should produce same embeddings" + assert torch.allclose(freqs_sin, freqs_sin_same), "Same positions should produce same embeddings" + + +class TestFlux2PackedModuleMapping: + """Test Flux2 packed module mapping functionality""" + + def test_packed_module_mapping(self, setup_tp_group): + """Verify to_qkv packing matches HF checkpoint""" + model = Flux2Transformer2DModel( + num_layers=1, + num_single_layers=0, + num_attention_heads=48, + attention_head_dim=128, + joint_attention_dim=15360, + ) + + # Verify stacked_params_mapping is correctly initialized + model.load_weights([]) # Trigger stacked_params_mapping initialization + assert model.stacked_params_mapping is not None + + # Verify mapping configuration + expected_mappings = [ + (".to_qkv.", ".to_q.", "q"), + (".to_qkv.", ".to_k.", "k"), + (".to_qkv.", ".to_v.", "v"), + (".add_kv_proj", ".add_q_proj", "q"), + (".add_kv_proj", ".add_k_proj", "k"), + (".add_kv_proj", ".add_v_proj", "v"), + ] + assert model.stacked_params_mapping == expected_mappings + + # Create mock HF checkpoint weights + hf_weights = [] + + # Mock HF separated Q/K/V weights + attn_block = model.transformer_blocks[0].attn + head_dim = 128 + num_heads = 48 + hidden_size = 6144 + # Full weight dimension + full_shard_size = num_heads * head_dim # 6144 + + # Q projection weights (full weights) + q_weight = torch.randn(full_shard_size, hidden_size) + hf_weights.append(("transformer_blocks.0.attn.to_q.weight", q_weight)) + + # K projection weights (full weights) + k_weight = torch.randn(full_shard_size, hidden_size) + hf_weights.append(("transformer_blocks.0.attn.to_k.weight", k_weight)) + + # V projection weights (full weights) + v_weight = torch.randn(full_shard_size, hidden_size) + hf_weights.append(("transformer_blocks.0.attn.to_v.weight", v_weight)) + + # Mock HF separated add_q/k/v projection weights + add_q_weight = torch.randn(full_shard_size, hidden_size) + hf_weights.append(("transformer_blocks.0.attn.add_q_proj.weight", add_q_weight)) + + add_k_weight = torch.randn(full_shard_size, hidden_size) + hf_weights.append(("transformer_blocks.0.attn.add_k_proj.weight", add_k_weight)) + + add_v_weight = torch.randn(full_shard_size, hidden_size) + hf_weights.append(("transformer_blocks.0.attn.add_v_proj.weight", add_v_weight)) + + # Execute weight loading + loaded_params = model.load_weights(hf_weights) + + # Verify weights are loaded + assert len(loaded_params) > 0 + + # Verify final QKV weights are correctly combined (considering TP sharding) + # With TP=2, dimension on each GPU should be half of full dimension + expected_qkv_shape = (full_shard_size * 3 // 2, hidden_size) # 18432/2 = 9216 + assert attn_block.to_qkv.weight.shape == expected_qkv_shape, ( + f"to_qkv weight dimension should be {expected_qkv_shape}, got {attn_block.to_qkv.weight.shape}" + ) + + expected_add_kv_shape = (full_shard_size * 3 // 2, hidden_size) + assert attn_block.add_kv_proj.weight.shape == expected_add_kv_shape, ( + f"add_kv_proj weight dimension should be {expected_add_kv_shape}, got {attn_block.add_kv_proj.weight.shape}" + ) + + def test_packed_mapping_edge_cases(self, setup_tp_group): + """Test edge cases for packed mapping""" + model = Flux2Transformer2DModel( + num_layers=1, + num_single_layers=1, + num_attention_heads=48, + attention_head_dim=128, + joint_attention_dim=15360, + ) + model.load_weights([]) + + # Test invalid weight names + invalid_weights = [("invalid.weight", torch.randn(10, 10))] + loaded_params = model.load_weights(invalid_weights) + assert len(loaded_params) == 0, "Should not load invalid weights" + + # Test to_out weight renaming + to_out_weight = torch.randn(6144, 6144) + weights = [("transformer_blocks.0.attn.to_out.0.weight", to_out_weight)] + loaded_params = model.load_weights(weights) + + # Check if to_out related weights are loaded + to_out_loaded = any("to_out" in p for p in loaded_params) + assert to_out_loaded, "to_out weights should be correctly renamed and loaded" + + # Test partial weight loading + partial_weights = [ + ("x_embedder.weight", torch.randn(6144, 128)), + ("transformer_blocks.0.attn.to_q.weight", torch.randn(6144, 6144)), # Full weights + ] + loaded_params = model.load_weights(partial_weights) + assert len(loaded_params) == 2, "Should load two weights" + assert "x_embedder.weight" in loaded_params + assert any("to_qkv" in p for p in loaded_params), "to_q should be mapped to to_qkv" From 00322c97f1fd464fca3e25c60748710a85ad79d7 Mon Sep 17 00:00:00 2001 From: wuzhongjian Date: Wed, 4 Mar 2026 20:56:44 +0800 Subject: [PATCH 09/14] [feature]: support Flux.2-dev model Signed-off-by: wuzhongjian --- vllm_omni/diffusion/models/flux2/pipeline_flux2.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py index e80e9940b25..bfcc038514c 100644 --- a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py +++ b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py @@ -1017,10 +1017,6 @@ def forward( ) self._num_timesteps = len(timesteps) - # handle guidance - guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) - guidance = guidance.expand(latents.shape[0]) - # 7. Denoising loop # We set the index here to remove DtoH sync, helpful especially during compilation. # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 @@ -1042,19 +1038,23 @@ def forward( noise_pred = self.transformer( hidden_states=latent_model_input, # (B, image_seq_len, C) timestep=timestep / 1000, - guidance=guidance, + guidance=None, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, # B, text_seq_len, 4 img_ids=latent_image_ids, # B, image_seq_len, 4 - joint_attention_kwargs=self._attention_kwargs, + joint_attention_kwargs=self.attention_kwargs, return_dict=False, )[0] noise_pred = noise_pred[:, : latents.size(1) :] # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if latents.dtype != latents_dtype and torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: From cc26148a8911a2cf2dd76261dc82e80d575fc9d6 Mon Sep 17 00:00:00 2001 From: wuzhongjian Date: Thu, 5 Mar 2026 09:20:20 +0800 Subject: [PATCH 10/14] [feature]: support Flux.2-dev model Signed-off-by: wuzhongjian --- vllm_omni/diffusion/models/flux2/pipeline_flux2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py index bfcc038514c..6a0b02abcda 100644 --- a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py +++ b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py @@ -763,8 +763,8 @@ def prepare_image_latents( image_latents = [] for image in images: image = image.to(device=device, dtype=dtype) - imagge_latent = self._encode_vae_image(image=image, generator=generator) - image_latents.append(imagge_latent) # (1, 128, 32, 32) + image_latent = self._encode_vae_image(image=image, generator=generator) + image_latents.append(image_latent) # (1, 128, 32, 32) image_latent_ids = self._prepare_image_ids(image_latents) From 78f08492736812c791030dfe5593ec3ee05a075f Mon Sep 17 00:00:00 2001 From: wuzhongjian Date: Thu, 5 Mar 2026 13:32:38 +0800 Subject: [PATCH 11/14] [feature]: support Flux.2-dev model Signed-off-by: wuzhongjian --- vllm_omni/diffusion/models/flux2/pipeline_flux2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py index 6a0b02abcda..7f07adf6f13 100644 --- a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py +++ b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py @@ -194,13 +194,13 @@ def format_input( for _ in cleaned_txt ] - for i, (el, images) in enumerate(zip(messages, images)): + for i, (el, batch_images) in enumerate(zip(messages, images)): # optionally add the images per batch element. - if images is not None: + if batch_images is not None: el.append( { "role": "user", - "content": [{"type": "image", "image": image_obj} for image_obj in images], + "content": [{"type": "image", "image": image_obj} for image_obj in batch_images], } ) # add the text. From 3bb68bd97e5cb355c5c1ceafdf07100b7b550f9a Mon Sep 17 00:00:00 2001 From: wuzhongjian Date: Wed, 11 Mar 2026 09:37:34 +0800 Subject: [PATCH 12/14] [feature]: support Flux.2-dev model Signed-off-by: wuzhongjian --- .../offline_inference/text_to_image/README.md | 17 +++++++++++++++++ ...nsformer.py => test_flux2_transformer_tp.py} | 0 2 files changed, 17 insertions(+) rename tests/diffusion/models/flux2/{test_flux2_transformer.py => test_flux2_transformer_tp.py} (100%) diff --git a/examples/offline_inference/text_to_image/README.md b/examples/offline_inference/text_to_image/README.md index ff2dda01020..2513627e852 100644 --- a/examples/offline_inference/text_to_image/README.md +++ b/examples/offline_inference/text_to_image/README.md @@ -106,6 +106,23 @@ python text_to_image.py \ --seed 42 ``` +### Flux.2-dev Models +To start Flux.2-dev with a single GPU, cpu-offload must be enabled because the total size of its weights exceeds the 80GB memory capacity of the GPU. +```bash +python examples/offline_inference/text_to_image/text_to_image.py \ + --model black-forest-labs/FLUX.2-dev \ + --prompt "a lovely bunny holding a sign that says 'vllm-omni'" \ + --seed 42 \ + --tensor-parallel-size 1 \ + --num-images-per-prompt 1 \ + --num-inference-steps 50 \ + --guidance-scale 4.0 \ + --height 1024 \ + --width 1024 \ + --enable-cpu-offload \ + --output flux2-dev.png +``` + ### Key Arguments **Common arguments:** diff --git a/tests/diffusion/models/flux2/test_flux2_transformer.py b/tests/diffusion/models/flux2/test_flux2_transformer_tp.py similarity index 100% rename from tests/diffusion/models/flux2/test_flux2_transformer.py rename to tests/diffusion/models/flux2/test_flux2_transformer_tp.py From 8f8639e02d18aded678e8f4b83d0399446d50f5a Mon Sep 17 00:00:00 2001 From: wuzhongjian Date: Wed, 11 Mar 2026 10:55:04 +0800 Subject: [PATCH 13/14] [feature]: support Flux.2-dev model Signed-off-by: wuzhongjian --- docs/user_guide/diffusion/parallelism_acceleration.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index bd2a58799f9..4c81adcaf05 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -37,6 +37,7 @@ The following table shows which models are currently supported by parallelism me | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ✅ | ✅ | N/A | | **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ❌ | ❌ | ✅ | ❌ | N/A | | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ✅ | ✅ | ❌ | N/A | +| **FLUX.2-dev** | `black-forest-labs/FLUX.2-dev` | ❌ | ❌ | ❌ | ✅ | ❌ | | **HunyuanImage3.0** | `tencent/HunyuanImage-3.0`, `tencent/HunyuanImage-3.0-Instruct` | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | !!! note "TP Limitations for Diffusion Models" From 1092a9c392a4856a770afb232b472c282bd7b1ee Mon Sep 17 00:00:00 2001 From: wuzhongjian Date: Wed, 11 Mar 2026 11:04:38 +0800 Subject: [PATCH 14/14] [feature]: support Flux.2-dev model Signed-off-by: wuzhongjian --- docs/user_guide/diffusion/parallelism_acceleration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index 4c81adcaf05..505f9627fd5 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -37,7 +37,7 @@ The following table shows which models are currently supported by parallelism me | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ✅ | ✅ | N/A | | **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ❌ | ❌ | ✅ | ❌ | N/A | | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ✅ | ✅ | ❌ | N/A | -| **FLUX.2-dev** | `black-forest-labs/FLUX.2-dev` | ❌ | ❌ | ❌ | ✅ | ❌ | +| **FLUX.2-dev** | `black-forest-labs/FLUX.2-dev` | ❌ | ❌ | ❌ | ✅ | ❌ | N/A | | **HunyuanImage3.0** | `tencent/HunyuanImage-3.0`, `tencent/HunyuanImage-3.0-Instruct` | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | !!! note "TP Limitations for Diffusion Models"