diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index bd2a58799f9..e9d65142945 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -35,7 +35,7 @@ The following table shows which models are currently supported by parallelism me | **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | ✅ | ❌ | N/A | | **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ✅ (TP=2 only) | ✅ | N/A | | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ✅ | ✅ | N/A | -| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ❌ | ❌ | ✅ | ❌ | N/A | +| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ✅ | ✅ | ❌ | ✅ | ❌ | N/A | | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ✅ | ✅ | ❌ | N/A | | **HunyuanImage3.0** | `tencent/HunyuanImage-3.0`, `tencent/HunyuanImage-3.0-Instruct` | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | diff --git a/vllm_omni/diffusion/layers/rope.py b/vllm_omni/diffusion/layers/rope.py index 461c25652e4..65d37d0b017 100644 --- a/vllm_omni/diffusion/layers/rope.py +++ b/vllm_omni/diffusion/layers/rope.py @@ -157,3 +157,29 @@ def forward_native( sin, interleaved=self.interleaved, ) + + +def apply_rope_to_qk( + rope: RotaryEmbedding, + query: torch.Tensor, + key: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Apply rotary positional embeddings to query and key tensors. + + Args: + rope: RotaryEmbedding instance for applying position embeddings + query: Query tensor [B, S, H, D] + key: Key tensor [B, S, H, D] + image_rotary_emb: Tuple of (cos, sin) tensors or None + + Returns: + Tuple of (query, key) with RoPE applied if rotary embeddings provided + """ + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + query = rope(query, cos, sin) + key = rope(key, cos, sin) + return query, key diff --git a/vllm_omni/diffusion/models/flux/flux_transformer.py b/vllm_omni/diffusion/models/flux/flux_transformer.py index db6f0d34ece..2979fd4f65a 100644 --- a/vllm_omni/diffusion/models/flux/flux_transformer.py +++ b/vllm_omni/diffusion/models/flux/flux_transformer.py @@ -30,7 +30,7 @@ from vllm_omni.diffusion.attention.layer import Attention from vllm_omni.diffusion.data import OmniDiffusionConfig -from vllm_omni.diffusion.layers.rope import RotaryEmbedding +from vllm_omni.diffusion.layers.rope import RotaryEmbedding, apply_rope_to_qk logger = init_logger(__name__) @@ -224,12 +224,7 @@ def forward( key = torch.cat([encoder_key, key], dim=1) value = torch.cat([encoder_value, value], dim=1) - if image_rotary_emb is not None: - cos, sin = image_rotary_emb # [S, D/2] - cos = cos.to(query.dtype) - sin = sin.to(query.dtype) - query = self.rope(query, cos, sin) - key = self.rope(key, cos, sin) + query, key = apply_rope_to_qk(self.rope, query, key, image_rotary_emb) # [S, D/2] hidden_states = self.attn( query, diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py index 2499b9c1692..c10f06751dd 100644 --- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -28,6 +28,7 @@ ) from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.normalization import AdaLayerNormContinuous +from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -39,8 +40,15 @@ from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata from vllm_omni.diffusion.attention.layer import Attention -from vllm_omni.diffusion.layers.rope import RotaryEmbedding +from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.sp_plan import ( + SequenceParallelInput, + SequenceParallelOutput, +) +from vllm_omni.diffusion.forward_context import get_forward_context +from vllm_omni.diffusion.layers.rope import RotaryEmbedding, apply_rope_to_qk +logger = init_logger(__name__) if TYPE_CHECKING: from vllm.model_executor.layers.quantization.base_config import QuantizationConfig @@ -98,6 +106,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Flux2Attention(nn.Module): def __init__( self, + parallel_config: DiffusionParallelConfig, query_dim: int, heads: int = 8, dim_head: int = 64, @@ -112,6 +121,7 @@ def __init__( quant_config: "QuantizationConfig | None" = None, ): super().__init__() + self.parallel_config = parallel_config self.head_dim = dim_head self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.query_dim = query_dim @@ -208,33 +218,87 @@ def forward( encoder_query = self.norm_added_q(encoder_query) encoder_key = self.norm_added_k(encoder_key) - query = torch.cat([encoder_query, query], dim=1) - key = torch.cat([encoder_key, key], dim=1) - value = torch.cat([encoder_value, value], dim=1) - - if image_rotary_emb is not None: - cos, sin = image_rotary_emb - cos = cos.to(query.dtype) - sin = sin.to(query.dtype) - query = self.rope(query, cos, sin) - key = self.rope(key, cos, sin) - - attn_metadata = None - if attention_mask is not None: - if attention_mask.dim() == 3: - attention_mask = attention_mask.unsqueeze(1) - attn_metadata = AttentionMetadata(attn_mask=attention_mask) - - hidden_states = self.attn(query, key, value, attn_metadata) - hidden_states = hidden_states.flatten(2, 3).to(query.dtype) - - if encoder_hidden_states is not None: - context_len = encoder_hidden_states.shape[1] - encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( - [context_len, hidden_states.shape[1] - context_len], - dim=1, - ) - encoder_hidden_states = self.to_add_out(encoder_hidden_states) + sp_size = self.parallel_config.sequence_parallel_size + forward_ctx = get_forward_context() + use_sp_joint_attention = sp_size is not None and sp_size > 1 and not forward_ctx.split_text_embed_in_sp + + if use_sp_joint_attention and image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + txt_len = encoder_query.shape[1] + txt_cos, img_cos = cos[:txt_len], cos[txt_len:] + txt_sin, img_sin = sin[:txt_len], sin[txt_len:] + query = self.rope(query, img_cos, img_sin) + key = self.rope(key, img_cos, img_sin) + encoder_query = self.rope(encoder_query, txt_cos, txt_sin) + encoder_key = self.rope(encoder_key, txt_cos, txt_sin) + + attn_metadata = AttentionMetadata( + joint_query=encoder_query, + joint_key=encoder_key, + joint_value=encoder_value, + joint_strategy="front", + ) + hidden_states_mask: torch.Tensor | None = kwargs.get("hidden_states_mask", None) + encoder_hidden_states_mask: torch.Tensor | None = kwargs.get("encoder_hidden_states_mask", None) + if hidden_states_mask is not None: + attn_metadata.attn_mask = hidden_states_mask + if encoder_hidden_states_mask is not None: + attn_metadata.joint_attn_mask = encoder_hidden_states_mask + + hidden_states = self.attn(query, key, value, attn_metadata) + hidden_states = hidden_states.flatten(2, 3).to(query.dtype) + + txt_len = encoder_hidden_states.shape[1] + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [txt_len, hidden_states.shape[1] - txt_len], + dim=1, + ) + encoder_hidden_states = self.to_add_out(encoder_hidden_states) + else: + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + query = self.rope(query, cos, sin) + key = self.rope(key, cos, sin) + + attn_metadata = None + if attention_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_metadata = AttentionMetadata(attn_mask=attention_mask) + + hidden_states = self.attn(query, key, value, attn_metadata) + hidden_states = hidden_states.flatten(2, 3).to(query.dtype) + + context_len = encoder_hidden_states.shape[1] + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [context_len, hidden_states.shape[1] - context_len], + dim=1, + ) + encoder_hidden_states = self.to_add_out(encoder_hidden_states) + else: + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + query = self.rope(query, cos, sin) + key = self.rope(key, cos, sin) + + attn_metadata = None + if attention_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_metadata = AttentionMetadata(attn_mask=attention_mask) + + hidden_states = self.attn(query, key, value, attn_metadata) + hidden_states = hidden_states.flatten(2, 3).to(query.dtype) hidden_states = self.to_out[0](hidden_states) hidden_states = self.to_out[1](hidden_states) @@ -251,6 +315,7 @@ class Flux2ParallelSelfAttention(nn.Module): def __init__( self, + parallel_config: DiffusionParallelConfig, query_dim: int, heads: int = 8, dim_head: int = 64, @@ -265,6 +330,7 @@ def __init__( quant_config: "QuantizationConfig | None" = None, ): super().__init__() + self.parallel_config = parallel_config self.head_dim = dim_head self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.query_dim = query_dim @@ -325,20 +391,57 @@ def forward( query = self.norm_q(query) key = self.norm_k(key) - if image_rotary_emb is not None: + sp_size = self.parallel_config.sequence_parallel_size + forward_ctx = get_forward_context() + text_seq_len = kwargs.get("text_seq_len", None) + use_sp_single_stream = ( + sp_size is not None and sp_size > 1 and not forward_ctx.split_text_embed_in_sp and text_seq_len is not None + ) + + if use_sp_single_stream and image_rotary_emb is not None: cos, sin = image_rotary_emb cos = cos.to(query.dtype) sin = sin.to(query.dtype) - query = self.rope(query, cos, sin) - key = self.rope(key, cos, sin) + txt_cos, img_cos = cos[:text_seq_len], cos[text_seq_len:] + txt_sin, img_sin = sin[:text_seq_len], sin[text_seq_len:] + + img_query = query[:, text_seq_len:] + img_key = key[:, text_seq_len:] + img_value = value[:, text_seq_len:] + text_query = query[:, :text_seq_len] + text_key = key[:, :text_seq_len] + text_value = value[:, :text_seq_len] + + img_query = self.rope(img_query, img_cos, img_sin) + img_key = self.rope(img_key, img_cos, img_sin) + text_query = self.rope(text_query, txt_cos, txt_sin) + text_key = self.rope(text_key, txt_cos, txt_sin) + + attn_metadata = AttentionMetadata( + joint_query=text_query, + joint_key=text_key, + joint_value=text_value, + joint_strategy="front", + ) + hidden_states_mask: torch.Tensor | None = kwargs.get("hidden_states_mask", None) + encoder_hidden_states_mask: torch.Tensor | None = kwargs.get("encoder_hidden_states_mask", None) + if hidden_states_mask is not None: + attn_metadata.attn_mask = hidden_states_mask + if encoder_hidden_states_mask is not None: + attn_metadata.joint_attn_mask = encoder_hidden_states_mask + + attn_output = self.attn(img_query, img_key, img_value, attn_metadata) + else: + query, key = apply_rope_to_qk(self.rope, query, key, image_rotary_emb) + + attn_metadata = None + if attention_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_metadata = AttentionMetadata(attn_mask=attention_mask) - attn_metadata = None - if attention_mask is not None: - if attention_mask.dim() == 3: - attention_mask = attention_mask.unsqueeze(1) - attn_metadata = AttentionMetadata(attn_mask=attention_mask) + attn_output = self.attn(query, key, value, attn_metadata) - attn_output = self.attn(query, key, value, attn_metadata) attn_output = attn_output.flatten(2, 3).to(query.dtype) mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states) @@ -350,6 +453,7 @@ def forward( class Flux2SingleTransformerBlock(nn.Module): def __init__( self, + parallel_config: DiffusionParallelConfig, dim: int, num_attention_heads: int, attention_head_dim: int, @@ -361,6 +465,7 @@ def __init__( super().__init__() self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) self.attn = Flux2ParallelSelfAttention( + parallel_config=parallel_config, query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, @@ -383,6 +488,13 @@ def forward( split_hidden_states: bool = False, text_seq_len: int | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for Flux2SingleTransformerBlock with SP support. + + In SP mode: image hidden_states is chunked (B, img_len/SP, D), + text encoder_hidden_states is full (B, txt_len, D). + The block concatenates them for joint attention. + """ if encoder_hidden_states is not None: text_seq_len = encoder_hidden_states.shape[1] hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) @@ -412,6 +524,7 @@ def forward( class Flux2TransformerBlock(nn.Module): def __init__( self, + parallel_config: DiffusionParallelConfig, dim: int, num_attention_heads: int, attention_head_dim: int, @@ -425,6 +538,7 @@ def __init__( self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) self.attn = Flux2Attention( + parallel_config=parallel_config, query_dim=dim, added_kv_proj_dim=dim, dim_head=attention_head_dim, @@ -525,6 +639,47 @@ def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return freqs_cos, freqs_sin +class Flux2RopePrepare(nn.Module): + """Prepares RoPE embeddings for sequence parallel. + + This module encapsulates the RoPE computation for Flux.2-klein. + For dual-stream attention, text components (outputs 0, 1) are replicated + across SP ranks, while image components (outputs 2, 3) are sharded. + + NOTE: The hidden_states projection is handled separately in forward() + so that _sp_plan can shard it at the root level. + """ + + def __init__(self, pos_embed: Flux2PosEmbed): + super().__init__() + self.pos_embed = pos_embed + + def forward( + self, + img_ids: torch.Tensor, + txt_ids: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute RoPE embeddings for text and image sequences. + + Args: + img_ids: Image position IDs (img_seq_len, n_axes) + txt_ids: Text position IDs (txt_seq_len, n_axes) + + Returns: + Tuple of cosine / sine components for text & image + in the order: (txt_cos, txt_sin, img_cos, img_sin) + + NOTE: careful about output orders if this is refactored in the + future; we need to match the _sp_plan indices, since text + components (0 & 1) need to be replicated across SP ranks, + while image components (2 & 3) must be sharded. + """ + img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids) + txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids) + return txt_freqs_cos, txt_freqs_sin, img_freqs_cos, img_freqs_sin + + class Flux2TimestepGuidanceEmbeddings(nn.Module): def __init__( self, @@ -580,10 +735,24 @@ def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor, class Flux2Transformer2DModel(nn.Module): """ The Transformer model introduced in Flux 2. + + Supports Sequence Parallelism (Ulysses and Ring) when configured via OmniDiffusionConfig. """ _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] + _sp_plan = { + "": { + "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, auto_pad=True), + }, + "rope_prepare": { + 2: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True), + 3: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True), + }, + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } + """SP plan: shard hidden_states at root level, shard img_freqs at rope_prepare, gather output at proj_out.""" + def __init__( self, patch_size: int = 1, @@ -600,6 +769,7 @@ def __init__( rope_theta: int = 2000, eps: float = 1e-6, guidance_embeds: bool = True, + od_config: OmniDiffusionConfig = None, quant_config: "QuantizationConfig | None" = None, ): super().__init__() @@ -622,6 +792,13 @@ def __init__( guidance_embeds=guidance_embeds, ) + if od_config is not None: + self.parallel_config = od_config.parallel_config + else: + from vllm_omni.diffusion.data import DiffusionParallelConfig + + self.parallel_config = DiffusionParallelConfig() + self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=list(axes_dims_rope)) self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( in_channels=timestep_guidance_channels, @@ -637,9 +814,12 @@ def __init__( self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False) self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False) + self.rope_prepare = Flux2RopePrepare(self.pos_embed) + self.transformer_blocks = nn.ModuleList( [ Flux2TransformerBlock( + parallel_config=self.parallel_config, dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, @@ -655,6 +835,7 @@ def __init__( self.single_transformer_blocks = nn.ModuleList( [ Flux2SingleTransformerBlock( + parallel_config=self.parallel_config, dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, @@ -695,6 +876,10 @@ def forward( num_txt_tokens = encoder_hidden_states.shape[1] + sp_size = self.parallel_config.sequence_parallel_size + if sp_size is not None and sp_size > 1: + get_forward_context().split_text_embed_in_sp = False + timestep = timestep.to(hidden_states.dtype) * 1000 if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 @@ -705,21 +890,44 @@ def forward( double_stream_mod_txt = self.double_stream_modulation_txt(temb) single_stream_mod = self.single_stream_modulation(temb)[0] - hidden_states = self.x_embedder(hidden_states) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) - if img_ids.ndim == 3: img_ids = img_ids[0] if txt_ids.ndim == 3: txt_ids = txt_ids[0] - image_rotary_emb = self.pos_embed(img_ids) - text_rotary_emb = self.pos_embed(txt_ids) + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + txt_freqs_cos, txt_freqs_sin, img_freqs_cos, img_freqs_sin = self.rope_prepare(img_ids, txt_ids) + concat_rotary_emb = ( - torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), - torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), + torch.cat([txt_freqs_cos, img_freqs_cos], dim=0), + torch.cat([txt_freqs_sin, img_freqs_sin], dim=0), ) + # Create separate masks for image and text portions for Ulysses SP joint attention + hidden_states_mask = None + encoder_hidden_states_mask = None + ctx = get_forward_context() + if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0: + batch_size = hidden_states.shape[0] + img_padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size + + hidden_states_mask = torch.ones( + batch_size, + img_padded_seq_len, + dtype=torch.bool, + device=hidden_states.device, + ) + hidden_states_mask[:, ctx.sp_original_seq_len :] = False + if hidden_states_mask.all(): + hidden_states_mask = None + + if hidden_states_mask is not None: + joint_attention_kwargs["hidden_states_mask"] = hidden_states_mask + if encoder_hidden_states_mask is not None: + joint_attention_kwargs["encoder_hidden_states_mask"] = encoder_hidden_states_mask + for block in self.transformer_blocks: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, @@ -739,6 +947,7 @@ def forward( temb_mod_params=single_stream_mod, image_rotary_emb=concat_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, + text_seq_len=num_txt_tokens, ) hidden_states = hidden_states[:, num_txt_tokens:, ...] diff --git a/vllm_omni/diffusion/models/z_image/z_image_transformer.py b/vllm_omni/diffusion/models/z_image/z_image_transformer.py index 3e61610b125..b68b6f70d94 100644 --- a/vllm_omni/diffusion/models/z_image/z_image_transformer.py +++ b/vllm_omni/diffusion/models/z_image/z_image_transformer.py @@ -49,7 +49,7 @@ get_forward_context, is_forward_context_available, ) -from vllm_omni.diffusion.layers.rope import RotaryEmbedding +from vllm_omni.diffusion.layers.rope import RotaryEmbedding, apply_rope_to_qk ADALN_EMBED_DIM = 256 SEQ_MULTI_OF = 32 @@ -329,10 +329,7 @@ def forward( query = self.norm_q(query) key = self.norm_k(key) - cos = cos.to(query.dtype) - sin = sin.to(query.dtype) - query = self.rope(query, cos, sin) - key = self.rope(key, cos, sin) + query, key = apply_rope_to_qk(self.rope, query, key, (cos, sin)) # Cast to correct dtype dtype = query.dtype query, key = query.to(dtype), key.to(dtype)