diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 980488852f9..27adbfeb83e 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -44,6 +44,7 @@ th { |`GlmImageForConditionalGeneration` | GLM-Image | `zai-org/GLM-Image` | |`NextStep11Pipeline` | NextStep-1.1 | `stepfun-ai/NextStep-1.1` | |`MiMoAudioForConditionalGeneration` | MiMo-Audio-7B-Instruct | `XiaomiMiMo/MiMo-Audio-7B-Instruct` | +|`Flux2Pipeline` | FLUX.2-dev | `black-forest-labs/FLUX.2-dev` | ## List of Supported Models for NPU diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index b0225e2ffb9..2c11c5ecb50 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -35,6 +35,7 @@ The following table shows which models are currently supported by parallelism me | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ✅ | ✅ | | **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ❌ | ❌ | ✅ | ❌ | | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ✅ | ✅ | ❌ | +| **FLUX.2-dev** | `black-forest-labs/FLUX.2-dev` | ❌ | ❌ | ❌ | ✅ | ❌ | !!! note "TP Limitations for Diffusion Models" We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the `text_encoder` component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP. diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 9fd14a2b070..170def6c17d 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -68,6 +68,7 @@ The following table shows which models are currently supported by each accelerat | **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | | **NextStep-1.1** | `stepfun-ai/NextStep-1.1` | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | +| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ### VideoGen diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py index f014d073e82..bea2f0caa30 100644 --- a/vllm_omni/diffusion/cache/cache_dit_backend.py +++ b/vllm_omni/diffusion/cache/cache_dit_backend.py @@ -346,6 +346,84 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool return refresh_cache_context +def enable_cache_for_flux2_klein(pipeline: Any, cache_config: Any) -> Callable[[int], None]: + """Enable cache-dit for FLUX.2-klein-4B pipeline. + + Args: + pipeline: The FLUX.2-klein-4B pipeline instance. + cache_config: DiffusionCacheConfig instance with cache configuration. + Returns: + A refresh function that can be called with a new ``num_inference_steps`` + to update the cache context for the pipeline. + """ + # Build DBCacheConfig for transformer + db_cache_config = _build_db_cache_config(cache_config) + # The Fn_compute_blocks = 2 override is the most important decision here, + # and the rationale (quality degradation at Fn=1) only lives in flux2_klein. + db_cache_config.Fn_compute_blocks = 2 + + calibrator = None + if cache_config.enable_taylorseer: + taylorseer_order = cache_config.taylorseer_order + calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order) + logger.info(f"TaylorSeer enabled with order={taylorseer_order}") + + # Build ParamsModifier for transformer + modifier = ParamsModifier( + cache_config=db_cache_config, + calibrator_config=calibrator, + ) + + logger.info( + f"Enabling cache-dit on Flux2-Klein transformer with BlockAdapter: " + f"Fn={db_cache_config.Fn_compute_blocks}, " + f"Bn={db_cache_config.Bn_compute_blocks}, " + f"W={db_cache_config.max_warmup_steps}, " + ) + + # Enable cache-dit using BlockAdapter for transformer + cache_dit.enable_cache( + ( + BlockAdapter( + transformer=pipeline.transformer, + blocks=[ + pipeline.transformer.transformer_blocks, + pipeline.transformer.single_transformer_blocks, + ], + forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_2], + params_modifiers=[modifier], + ) + ), + cache_config=db_cache_config, + ) + + def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: + """Refresh cache context for the transformer with new num_inference_steps. + + Args: + pipeline: The FLUX.2-klein-4B pipeline instance. + num_inference_steps: New number of inference steps. + """ + if cache_config.scm_steps_mask_policy is None: + cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose) + else: + cache_dit.refresh_context( + pipeline.transformer, + cache_config=DBCacheConfig().reset( + num_inference_steps=num_inference_steps, + steps_computation_mask=cache_dit.steps_mask( + mask_policy=cache_config.scm_steps_mask_policy, + total_steps=num_inference_steps, + ), + Fn_compute_blocks=db_cache_config.Fn_compute_blocks, + steps_computation_policy=cache_config.scm_steps_policy, + ), + verbose=verbose, + ) + + return refresh_cache_context + + def enable_cache_for_sd3(pipeline: Any, cache_config: Any) -> Callable[[int], None]: """Enable cache-dit for StableDiffusion3Pipeline. @@ -859,6 +937,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool "Wan22I2VPipeline": enable_cache_for_wan22, "Wan22TI2VPipeline": enable_cache_for_wan22, "FluxPipeline": enable_cache_for_flux, + "Flux2KleinPipeline": enable_cache_for_flux2_klein, "LongCatImagePipeline": enable_cache_for_longcat_image, "LongCatImageEditPipeline": enable_cache_for_longcat_image, "StableDiffusion3Pipeline": enable_cache_for_sd3, diff --git a/vllm_omni/diffusion/models/flux2/__init__.py b/vllm_omni/diffusion/models/flux2/__init__.py new file mode 100644 index 00000000000..bce893e69d9 --- /dev/null +++ b/vllm_omni/diffusion/models/flux2/__init__.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Flux2 diffusion model components.""" + +from vllm_omni.diffusion.models.flux2.flux2_transformer import ( + Flux2Transformer2DModel, +) +from vllm_omni.diffusion.models.flux2.pipeline_flux2 import ( + Flux2Pipeline, + get_flux2_post_process_func, +) + +__all__ = [ + "Flux2Pipeline", + "Flux2Transformer2DModel", + "get_flux2_post_process_func", +] diff --git a/vllm_omni/diffusion/models/flux2/flux2_transformer.py b/vllm_omni/diffusion/models/flux2/flux2_transformer.py new file mode 100644 index 00000000000..478d6cbaf7a --- /dev/null +++ b/vllm_omni/diffusion/models/flux2/flux2_transformer.py @@ -0,0 +1,764 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from types import SimpleNamespace +from typing import Any + +import torch +from diffusers.models.embeddings import ( + TimestepEmbedding, + Timesteps, + get_1d_rotary_pos_embed, +) +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import AdaLayerNormContinuous +from diffusers.utils import is_torch_npu_available +from torch import nn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.layers.rope import RotaryEmbedding + + +class Flux2SwiGLU(nn.Module): + """SwiGLU activation used by Flux2.""" + + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(2, dim=-1) + return self.gate_fn(x1) * x2 + + +class Flux2FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: int | None = None, + mult: float = 3.0, + inner_dim: int | None = None, + bias: bool = False, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out or dim + + self.linear_in = MergedColumnParallelLinear( + dim, + [inner_dim, inner_dim], + bias=bias, + return_bias=False, + ) + self.act_fn = Flux2SwiGLU() + self.linear_out = RowParallelLinear( + inner_dim, + dim_out, + bias=bias, + input_is_parallel=True, + return_bias=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_in(x) + x = self.act_fn(x) + return self.linear_out(x) + + +class Flux2Attention(nn.Module): + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: int | None = None, + added_proj_bias: bool | None = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + ): + super().__init__() + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + self.dropout = dropout + self.added_kv_proj_dim = added_kv_proj_dim + + self.to_qkv = QKVParallelLinear( + hidden_size=query_dim, + head_size=self.head_dim, + total_num_heads=self.heads, + bias=bias, + ) + self.query_num_heads = self.to_qkv.num_heads + self.kv_num_heads = self.to_qkv.num_kv_heads + + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + + self.to_out = nn.ModuleList( + [ + RowParallelLinear( + self.inner_dim, + self.out_dim, + bias=out_bias, + input_is_parallel=True, + return_bias=False, + ), + nn.Dropout(dropout), + ] + ) + + if added_kv_proj_dim is not None: + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + self.add_kv_proj = QKVParallelLinear( + hidden_size=added_kv_proj_dim, + head_size=self.head_dim, + total_num_heads=self.heads, + bias=added_proj_bias, + ) + self.add_query_num_heads = self.add_kv_proj.num_heads + self.add_kv_num_heads = self.add_kv_proj.num_kv_heads + self.to_add_out = RowParallelLinear( + self.inner_dim, + query_dim, + bias=out_bias, + input_is_parallel=True, + return_bias=False, + ) + + self.rope = RotaryEmbedding(is_neox_style=False) + self.attn = Attention( + num_heads=self.query_num_heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + num_kv_heads=self.kv_num_heads, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + qkv, _ = self.to_qkv(hidden_states) + query, key, value = qkv.chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and self.added_kv_proj_dim is not None: + encoder_qkv, _ = self.add_kv_proj(encoder_hidden_states) + encoder_query, encoder_key, encoder_value = encoder_qkv.chunk(3, dim=-1) + + query = query.unflatten(-1, (self.query_num_heads, -1)) + key = key.unflatten(-1, (self.kv_num_heads, -1)) + value = value.unflatten(-1, (self.kv_num_heads, -1)) + + query = self.norm_q(query) + key = self.norm_k(key) + + if encoder_hidden_states is not None and self.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (self.add_query_num_heads, -1)) + encoder_key = encoder_key.unflatten(-1, (self.add_kv_num_heads, -1)) + encoder_value = encoder_value.unflatten(-1, (self.add_kv_num_heads, -1)) + + encoder_query = self.norm_added_q(encoder_query) + encoder_key = self.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + query = self.rope(query, cos, sin) + key = self.rope(key, cos, sin) + + attn_metadata = None + if attention_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_metadata = AttentionMetadata(attn_mask=attention_mask) + + hidden_states = self.attn(query, key, value, attn_metadata) + hidden_states = hidden_states.flatten(2, 3).to(query.dtype) + + if encoder_hidden_states is not None: + context_len = encoder_hidden_states.shape[1] + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [context_len, hidden_states.shape[1] - context_len], + dim=1, + ) + encoder_hidden_states = self.to_add_out(encoder_hidden_states) + + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + return hidden_states + + +class Flux2ParallelSelfAttention(nn.Module): + """ + Parallel attention block that fuses QKV projections with MLP input projections. + """ + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + mlp_ratio: float = 4.0, + mlp_mult_factor: int = 2, + ): + super().__init__() + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + self.dropout = dropout + + self.mlp_ratio = mlp_ratio + self.mlp_hidden_dim = int(query_dim * self.mlp_ratio) + self.mlp_mult_factor = mlp_mult_factor + + self.to_qkv_mlp_proj = ColumnParallelLinear( + self.query_dim, + self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, + bias=bias, + gather_output=True, + ) + self.mlp_act_fn = Flux2SwiGLU() + + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + + self.to_out = ColumnParallelLinear( + self.inner_dim + self.mlp_hidden_dim, + self.out_dim, + bias=out_bias, + gather_output=True, + ) + self.rope = RotaryEmbedding(is_neox_style=False) + self.attn = Attention( + num_heads=self.heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: + hidden_states, _ = self.to_qkv_mlp_proj(hidden_states) + qkv, mlp_hidden_states = torch.split( + hidden_states, + [3 * self.inner_dim, self.mlp_hidden_dim * self.mlp_mult_factor], + dim=-1, + ) + + query, key, value = qkv.chunk(3, dim=-1) + query = query.unflatten(-1, (self.heads, -1)) + key = key.unflatten(-1, (self.heads, -1)) + value = value.unflatten(-1, (self.heads, -1)) + + query = self.norm_q(query) + key = self.norm_k(key) + + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + query = self.rope(query, cos, sin) + key = self.rope(key, cos, sin) + + attn_metadata = None + if attention_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_metadata = AttentionMetadata(attn_mask=attention_mask) + + attn_output = self.attn(query, key, value, attn_metadata) + attn_output = attn_output.flatten(2, 3).to(query.dtype) + + mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states) + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=-1) + hidden_states, _ = self.to_out(hidden_states) + return hidden_states + + +class Flux2SingleTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.attn = Flux2ParallelSelfAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + out_bias=bias, + eps=eps, + mlp_ratio=mlp_ratio, + mlp_mult_factor=2, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None, + temb_mod_params: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + split_hidden_states: bool = False, + text_seq_len: int | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if encoder_hidden_states is not None: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + mod_shift, mod_scale, mod_gate = temb_mod_params + + norm_hidden_states = self.norm(hidden_states) + norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift + + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = hidden_states + mod_gate * attn_output + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + if split_hidden_states: + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + return hidden_states + + +class Flux2TransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + self.attn = Flux2Attention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + added_proj_bias=bias, + out_bias=bias, + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb_mod_params_img: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + temb_mod_params_txt: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + joint_attention_kwargs = joint_attention_kwargs or {} + + (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img + (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt + + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa + + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) + norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa + + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + attn_output = gate_msa * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + gate_mlp * ff_output + + context_attn_output = c_gate_msa * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class Flux2PosEmbed(nn.Module): + def __init__(self, theta: int, axes_dim: list[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # Expected ids shape: [S, len(self.axes_dim)] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1] + for i in range(len(self.axes_dim)): + freqs_cis = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[..., i], + theta=self.theta, + use_real=False, + freqs_dtype=freqs_dtype, + ) + cos_out.append(freqs_cis.real) + sin_out.append(freqs_cis.imag) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class Flux2TimestepGuidanceEmbeddings(nn.Module): + def __init__( + self, + in_channels: int = 256, + embedding_dim: int = 6144, + bias: bool = False, + guidance_embeds: bool = True, + ): + super().__init__() + self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=embedding_dim, + sample_proj_bias=bias, + ) + + if guidance_embeds: + self.guidance_embedder = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=embedding_dim, + sample_proj_bias=bias, + ) + else: + self.guidance_embedder = None + + def forward(self, timestep: torch.Tensor, guidance: torch.Tensor | None) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) + + if guidance is not None and self.guidance_embedder is not None: + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) + return timesteps_emb + guidance_emb + return timesteps_emb + + +class Flux2Modulation(nn.Module): + def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False): + super().__init__() + self.mod_param_sets = mod_param_sets + + self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias) + self.act_fn = nn.SiLU() + + def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]: + mod = self.act_fn(temb) + mod = self.linear(mod) + + if mod.ndim == 2: + mod = mod.unsqueeze(1) + mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1) + # Return tuple of 3-tuples of modulation params shift/scale/gate + return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets)) + + +class Flux2Transformer2DModel(nn.Module): + """ + The Transformer model introduced in Flux 2. + """ + + _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] + packed_modules_mapping = { + "to_qkv": ["to_q", "to_k", "to_v"], + "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], + } + + def __init__( + self, + patch_size: int = 1, + in_channels: int = 128, + out_channels: int | None = None, + num_layers: int = 8, + num_single_layers: int = 48, + attention_head_dim: int = 128, + num_attention_heads: int = 48, + joint_attention_dim: int = 15360, + timestep_guidance_channels: int = 256, + mlp_ratio: float = 3.0, + axes_dims_rope: tuple[int, ...] = (32, 32, 32, 32), + rope_theta: int = 2000, + eps: float = 1e-6, + guidance_embeds: bool = True, + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.config = SimpleNamespace( + patch_size=patch_size, + in_channels=in_channels, + out_channels=self.out_channels, + num_layers=num_layers, + num_single_layers=num_single_layers, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + joint_attention_dim=joint_attention_dim, + timestep_guidance_channels=timestep_guidance_channels, + mlp_ratio=mlp_ratio, + axes_dims_rope=axes_dims_rope, + rope_theta=rope_theta, + eps=eps, + guidance_embeds=guidance_embeds, + ) + + self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope) + self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( + in_channels=timestep_guidance_channels, + embedding_dim=self.inner_dim, + bias=False, + guidance_embeds=guidance_embeds, + ) + + self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False) + + self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False) + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False) + + self.transformer_blocks = nn.ModuleList( + [ + Flux2TransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + Flux2SingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False + ) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + return_dict: bool = True, + ) -> torch.Tensor | Transformer2DModelOutput: + joint_attention_kwargs = joint_attention_kwargs or {} + + num_txt_tokens = encoder_hidden_states.shape[1] + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = self.time_guidance_embed(timestep, guidance) + + double_stream_mod_img = self.double_stream_modulation_img(temb) + double_stream_mod_txt = self.double_stream_modulation_txt(temb) + single_stream_mod = self.single_stream_modulation(temb)[0] + + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if img_ids.ndim == 3: + img_ids = img_ids[0] + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + + if is_torch_npu_available(): + freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu()) + image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu()) + freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu()) + text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu()) + else: + image_rotary_emb = self.pos_embed(img_ids) + text_rotary_emb = self.pos_embed(txt_ids) + concat_rotary_emb = ( + torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), + torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), + ) + + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_mod_params_img=double_stream_mod_img, + temb_mod_params_txt=double_stream_mod_txt, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for index_block, block in enumerate(self.single_transformer_blocks): + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=None, + temb_mod_params=single_stream_mod, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + hidden_states = hidden_states[:, num_txt_tokens:, ...] + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + (".to_qkv", ".to_q", "q"), + (".to_qkv", ".to_k", "k"), + (".to_qkv", ".to_v", "v"), + (".add_kv_proj", ".add_q_proj", "q"), + (".add_kv_proj", ".add_k_proj", "k"), + (".add_kv_proj", ".add_v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + + for name, buffer in self.named_buffers(): + if name.endswith(".beta") or name.endswith(".eps"): + params_dict[name] = buffer + + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "to_qkvkv_mlp_proj" in name: + name = name.replace("to_qkvkv_mlp_proj", "to_qkv_mlp_proj") + if "to_qkv_mlp_proj" in name: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py new file mode 100644 index 00000000000..e057c985caa --- /dev/null +++ b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py @@ -0,0 +1,1078 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import inspect +import json +import logging +import math +import os +from collections.abc import Callable, Iterable +from typing import Any, cast + +import numpy as np +import PIL.Image +import torch +from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.flux2.pipeline_flux2 import UPSAMPLING_MAX_IMAGE_SIZE +from diffusers.pipelines.flux2.system_messages import ( + SYSTEM_MESSAGE, + SYSTEM_MESSAGE_UPSAMPLING_I2I, + SYSTEM_MESSAGE_UPSAMPLING_T2I, +) +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import AutoProcessor, Mistral3ForConditionalGeneration, PixtralProcessor +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.flux2 import Flux2Transformer2DModel +from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific + +logger = logging.getLogger(__name__) + + +class Flux2ImageProcessor(VaeImageProcessor): + """Image processor to preprocess the reference image for Flux2.""" + + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 16, + vae_latent_channels: int = 32, + do_normalize: bool = True, + do_convert_rgb: bool = True, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + vae_latent_channels=vae_latent_channels, + do_normalize=do_normalize, + do_convert_rgb=do_convert_rgb, + ) + + @staticmethod + def check_image_input( + image: PIL.Image.Image, + max_aspect_ratio: int = 8, + min_side_length: int = 64, + max_area: int = 1024 * 1024, + ) -> PIL.Image.Image: + if not isinstance(image, PIL.Image.Image): + raise ValueError(f"Image must be a PIL.Image.Image, got {type(image)}") + + width, height = image.size + if width < min_side_length or height < min_side_length: + raise ValueError(f"Image too small: {width}x{height}. Both dimensions must be at least {min_side_length}px") + + aspect_ratio = max(width / height, height / width) + if aspect_ratio > max_aspect_ratio: + raise ValueError( + f"Aspect ratio too extreme: {width}x{height} (ratio: {aspect_ratio:.1f}:1). " + f"Maximum allowed ratio is {max_aspect_ratio}:1" + ) + + if width * height > max_area: + logger.warning("Image area exceeds recommended maximum; resizing will be applied.") + + return image + + @staticmethod + def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image: + image_width, image_height = image.size + scale = math.sqrt(target_area / (image_width * image_height)) + width = int(image_width * scale) + height = int(image_height * scale) + return image.resize((width, height), PIL.Image.Resampling.LANCZOS) + + @staticmethod + def _resize_if_exceeds_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image: + image_width, image_height = image.size + if image_width * image_height <= target_area: + return image + return Flux2ImageProcessor._resize_to_target_area(image, target_area) + + def _resize_and_crop(self, image: PIL.Image.Image, width: int, height: int) -> PIL.Image.Image: + image_width, image_height = image.size + left = (image_width - width) // 2 + top = (image_height - height) // 2 + right = left + width + bottom = top + height + return image.crop((left, top, right, bottom)) + + @staticmethod + def concatenate_images(images: list[PIL.Image.Image]) -> PIL.Image.Image: + if len(images) == 1: + return images[0].copy() + + images = [img.convert("RGB") if img.mode != "RGB" else img for img in images] + total_width = sum(img.width for img in images) + max_height = max(img.height for img in images) + background_color = (255, 255, 255) + new_img = PIL.Image.new("RGB", (total_width, max_height), background_color) + + x_offset = 0 + for img in images: + y_offset = (max_height - img.height) // 2 + new_img.paste(img, (x_offset, y_offset)) + x_offset += img.width + + return new_img + + +def get_flux2_post_process_func( + od_config: OmniDiffusionConfig, +): + if od_config.output_type == "latent": + return lambda x: x + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + vae_scale_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) if "block_out_channels" in vae_config else 8 + + image_processor = Flux2ImageProcessor(vae_scale_factor=vae_scale_factor * 2) + + def post_process_func(images: torch.Tensor): + return image_processor.postprocess(images) + + return post_process_func + + +# Adapted from +# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68 +def format_input( + prompts: list[str], + system_message: str = SYSTEM_MESSAGE, + images: list[PIL.Image.Image] | list[list[PIL.Image.Image]] = None, +): + """ + Format a batch of text prompts into the conversation format expected by apply_chat_template. Optionally, add images + to the input. + + Args: + prompts: List of text prompts + system_message: System message to use (default: CREATIVE_SYSTEM_MESSAGE) + images (optional): List of images to add to the input. + + Returns: + List of conversations, where each conversation is a list of message dicts + """ + # Remove [IMG] tokens from prompts to avoid Pixtral validation issues + # when truncation is enabled. The processor counts [IMG] tokens and fails + # if the count changes after truncation. + cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + + if images is None or len(images) == 0: + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + else: + assert len(images) == len(prompts), "Number of images must match number of prompts" + messages = [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + ] + for _ in cleaned_txt + ] + + for i, (el, images) in enumerate(zip(messages, images)): + # optionally add the images per batch element. + if images is not None: + el.append( + { + "role": "user", + "content": [{"type": "image", "image": image_obj} for image_obj in images], + } + ) + # add the text. + el.append( + { + "role": "user", + "content": [{"type": "text", "text": cleaned_txt[i]}], + } + ) + + return messages + + +# Adapted from +# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L49C5-L66C19 +def _validate_and_process_images( + images: list[list[PIL.Image.Image]] | list[PIL.Image.Image], + image_processor: Flux2ImageProcessor, + upsampling_max_image_size: int, +) -> list[list[PIL.Image.Image]]: + # Simple validation: ensure it's a list of PIL images or list of lists of PIL images + if not images: + return [] + + # Check if it's a list of lists or a list of images + if isinstance(images[0], PIL.Image.Image): + # It's a list of images, convert to list of lists + images = [[im] for im in images] + + # potentially concatenate multiple images to reduce the size + images = [[image_processor.concatenate_images(img_i)] if len(img_i) > 1 else img_i for img_i in images] + + # cap the pixels + images = [ + [image_processor._resize_if_exceeds_area(img_i, upsampling_max_image_size) for img_i in img_i] + for img_i in images + ] + return images + + +# Taken from +# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L251 +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents(encoder_output: torch.Tensor, generator: torch.Generator = None, sample_mode: str = "sample"): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Flux2Pipeline(nn.Module, CFGParallelMixin, SupportImageInput): + """Flux2 pipeline for text-to-image generation.""" + + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + support_image_input = True + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ) + ] + + self._execution_device = get_local_device() + model = od_config.model + # Check if model is a local path + local_files_only = os.path.exists(model) + + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model, subfolder="scheduler", local_files_only=local_files_only + ) + self.text_encoder = Mistral3ForConditionalGeneration.from_pretrained( + model, subfolder="text_encoder", local_files_only=local_files_only + ) + self.tokenizer = PixtralProcessor.from_pretrained( + model, subfolder="tokenizer", local_files_only=local_files_only + ) + self.vae = AutoencoderKLFlux2.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( + self._execution_device + ) + transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, Flux2Transformer2DModel) + self.transformer = Flux2Transformer2DModel(**transformer_kwargs) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 512 + self.default_sample_size = 128 + + self.system_message = SYSTEM_MESSAGE + self.system_message_upsampling_t2i = SYSTEM_MESSAGE_UPSAMPLING_T2I + self.system_message_upsampling_i2i = SYSTEM_MESSAGE_UPSAMPLING_I2I + self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE + + self._guidance_scale = None + self._attention_kwargs = None + self._num_timesteps = None + self._current_timestep = None + self._interrupt = False + + @staticmethod + def _get_mistral_3_small_prompt_embeds( + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + prompt: str | list[str], + dtype: torch.dtype | None = None, + device: torch.device | None = None, + max_sequence_length: int = 512, + system_message: str = SYSTEM_MESSAGE, + hidden_states_layers: list[int] = (10, 20, 30), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + # Format input messages + messages_batch = format_input(prompts=prompt, system_message=system_message) + + # Process all messages at once + inputs = tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + # Move to device + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @staticmethod + def _prepare_text_ids( + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: torch.Tensor | None = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + seq_positions = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, seq_positions) + out_ids.append(coords) + + return torch.stack(out_ids) + + @staticmethod + def _prepare_latent_ids( + latents: torch.Tensor, # (B, C, H, W) + ): + r""" + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents (torch.Tensor): + Latent tensor of shape (B, C, H, W) + + Returns: + torch.Tensor: + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, + H=[0..H-1], W=[0..W-1], L=0 + """ + + batch_size, _, height, width = latents.shape + + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + layer_ids = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, layer_ids) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + def _prepare_image_ids( + image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + scale: int = 10, + ): + r""" + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + This function creates a unique coordinate for every pixel/patch across all input latent with different + dimensions. + + Args: + image_latents (List[torch.Tensor]): + A list of image latent feature tensors, typically of shape (C, H, W). + scale (int, optional): + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th + latent is: 'scale + scale * i'. Defaults to 10. + + Returns: + torch.Tensor: + The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all + input latents. + + Coordinate Components (Dimension 4): + - T (Time): The unique index indicating which latent image the coordinate belongs to. + - H (Height): The row index within that latent image. + - W (Width): The column index within that latent image. + - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1) + """ + + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + # create time offset for each reference image + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + def _patchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + def _unpatchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @staticmethod + def _pack_latents(latents): + """ + pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) + """ + + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + return latents + + @staticmethod + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + def upsample_prompt( + self, + prompt: str | list[str], + images: list[PIL.Image.Image] | list[list[PIL.Image.Image]] = None, + temperature: float = 0.15, + device: torch.device = None, + ) -> list[str]: + prompt = [prompt] if isinstance(prompt, str) else prompt + device = self.text_encoder.device if device is None else device + + # Set system message based on whether images are provided + if images is None or len(images) == 0 or images[0] is None: + system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I + else: + system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I + + # Validate and process the input images + if images: + images = _validate_and_process_images(images, self.image_processor, self.upsampling_max_image_size) + + # Format input messages + messages_batch = format_input(prompts=prompt, system_message=system_message, images=images) + + # Process all messages at once + # with image processing a too short max length can throw an error in here. + inputs = self.tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=2048, + ) + + # Move to device + inputs["input_ids"] = inputs["input_ids"].to(device) + inputs["attention_mask"] = inputs["attention_mask"].to(device) + + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(device, self.text_encoder.dtype) + + # Generate text using the model's generate method + generated_ids = self.text_encoder.generate( + **inputs, + max_new_tokens=512, + do_sample=True, + temperature=temperature, + use_cache=True, + ) + + # Decode only the newly generated tokens (skip input tokens) + # Extract only the generated portion + input_length = inputs["input_ids"].shape[1] + generated_tokens = generated_ids[:, input_length:] + + upsampled_prompt = self.tokenizer.tokenizer.batch_decode( + generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + return upsampled_prompt + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + text_encoder_out_layers: tuple[int, ...] = (10, 20, 30), + ): + device = device or self._execution_device + + if prompt is None: + prompt = "" + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self._get_mistral_3_small_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + system_message=self.system_message, + hidden_states_layers=text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self._prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + def prepare_latents( + self, + batch_size, + num_latents_channels, + height, + width, + dtype, + device, + generator: torch.Generator, + latents: torch.Tensor | None = None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_latents_channels * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(device) + + latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] + return latents, latent_ids + + def prepare_image_latents( + self, + images: list[torch.Tensor], + batch_size, + generator: torch.Generator, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + imagge_latent = self._encode_vae_image(image=image, generator=generator) + image_latents.append(imagge_latent) # (1, 128, 32, 32) + + image_latent_ids = self._prepare_image_ids(image_latents) + + # Pack each latent and concatenate + packed_latents = [] + for latent in image_latents: + # latent: (1, 128, 32, 32) + packed = self._pack_latents(latent) # (1, 1024, 128) + packed = packed.squeeze(0) # (1024, 128) - remove batch dim + packed_latents.append(packed) + + # Concatenate all reference tokens along sequence dimension + image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128) + image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + return image_latents, image_latent_ids + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if ( + height is not None + and height % (self.vae_scale_factor * 2) != 0 + or width is not None + and width % (self.vae_scale_factor * 2) != 0 + ): + logger.warning( + "`height` and `width` have to be divisible by %s but are %s and %s. " + "Dimensions will be resized accordingly", + self.vae_scale_factor * 2, + height, + width, + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, " + f"but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def forward( + self, + req: OmniDiffusionRequest, + image: PIL.Image.Image | list[PIL.Image.Image] | None = None, + prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float | None = 4.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + text_encoder_out_layers: tuple[int, ...] = (10, 20, 30), + caption_upsample_temperature: float = None, + ) -> DiffusionOutput: + if len(req.prompts) > 1: + logger.warning( + """This model only supports a single prompt, not a batched request.""", + """Taking only the first image for now.""", + ) + first_prompt = req.prompts[0] + prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") + + if ( + raw_image := None + if isinstance(first_prompt, str) + else first_prompt.get("multi_modal_data", {}).get("image") + ) is None: + pass # use image from param list + elif isinstance(raw_image, list): + image = [PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im) for im in raw_image] + else: + image = PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image) + + height = req.sampling_params.height or height + width = req.sampling_params.width or width + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas or sigmas + guidance_scale = ( + req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale + ) + generator = req.sampling_params.generator or generator + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) + max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + text_encoder_out_layers = req.sampling_params.extra_args.get("text_encoder_out_layers", text_encoder_out_layers) + + req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + # If at list one prompt is provided as an embedding, + # Then assume that the user wants to provide embeddings for all prompts, and enter this if block + # If the user in fact provides mixed input format, req_prompt_embeds will have some None's + # And `torch.stack` automatically raises an exception for us + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + + req_negative_prompt_embeds = [ + p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts + ] + if any(p is not None for p in req_negative_prompt_embeds): + torch.stack(req_negative_prompt_embeds) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. prepare text embeddings + if caption_upsample_temperature: + prompt = self.upsample_prompt(prompt, images=image, temperature=caption_upsample_temperature, device=device) + prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + # 4. process images + if image is not None and not isinstance(image, list): + image = [image] + + condition_images = None + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + + condition_images = [] + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + height = height or image_height + width = width or image_width + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 5. prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_ids = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_latents_channels=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + image_latents = None + image_latent_ids = None + if condition_images is not None: + image_latents, image_latent_ids = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + generator=generator, + device=device, + dtype=self.vae.dtype, + ) + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + + # 7. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents.to(self.transformer.dtype) + latent_image_ids = latent_ids + + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) + latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) + + noise_pred = self.transformer( + hidden_states=latent_model_input, # (B, image_seq_len, C) + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=latent_image_ids, # B, image_seq_len, 4 + joint_attention_kwargs=self._attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1) :] + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + self._current_timestep = None + + latents = self._unpack_latents_with_ids(latents, latent_ids) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + latents = self._unpatchify_latents(latents) + if output_type == "latent": + image = latents + else: + if latents.dtype != self.vae.dtype: + latents = latents.to(self.vae.dtype) + image = self.vae.decode(latents, return_dict=False)[0] + + return DiffusionOutput(output=image) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index ed99b1473b3..b88c237fbd2 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -110,6 +110,11 @@ "pipeline_omnigen2", "OmniGen2Pipeline", ), + "Flux2Pipeline": ( + "flux2", + "pipeline_flux2", + "Flux2Pipeline", + ), } @@ -296,6 +301,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "NextStep11Pipeline": "get_nextstep11_post_process_func", "FluxPipeline": "get_flux_post_process_func", "OmniGen2Pipeline": "get_omnigen2_post_process_func", + "Flux2Pipeline": "get_flux2_post_process_func", } _DIFFUSION_PRE_PROCESS_FUNCS = {