From e8fb7395f796e9d1e60424dc096cf8c965b8cd96 Mon Sep 17 00:00:00 2001 From: Lancer Date: Fri, 6 Feb 2026 17:16:54 +0800 Subject: [PATCH 01/11] [Feat] support SP for FLUX.2-klein Signed-off-by: Lancer --- .../flux2_klein/flux2_klein_transformer.py | 80 ++++++++++++++++++- .../flux2_klein/pipeline_flux2_klein.py | 2 +- 2 files changed, 77 insertions(+), 5 deletions(-) diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py index ee10d2e0e4d..8e1bf3c7af5 100644 --- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -28,6 +28,7 @@ ) from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.normalization import AdaLayerNormContinuous +from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -39,7 +40,17 @@ from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.distributed.parallel_state import ( + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group, +) +from vllm_omni.diffusion.forward_context import get_forward_context from vllm_omni.diffusion.layers.rope import RotaryEmbedding +from vllm_omni.platforms import current_omni_platform + +logger = init_logger(__name__) class Flux2SwiGLU(nn.Module): @@ -334,6 +345,12 @@ def forward( class Flux2SingleTransformerBlock(nn.Module): + """ + Single-stream Transformer block for Flux 2 with SP (Sequence Parallelism) support. + + SP handling is delegated to Flux2Attention via the forward context. + """ + def __init__( self, dim: int, @@ -367,6 +384,13 @@ def forward( split_hidden_states: bool = False, text_seq_len: int | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for Flux2SingleTransformerBlock with SP support. + + In SP mode: image hidden_states is chunked (B, img_len/SP, D), + text encoder_hidden_states is full (B, txt_len, D). + The block concatenates them for joint attention. + """ if encoder_hidden_states is not None: text_seq_len = encoder_hidden_states.shape[1] hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) @@ -556,6 +580,8 @@ def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor, class Flux2Transformer2DModel(nn.Module): """ The Transformer model introduced in Flux 2. + + Supports Sequence Parallelism (Ulysses and Ring) when configured via OmniDiffusionConfig. """ _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] @@ -580,6 +606,7 @@ def __init__( rope_theta: int = 2000, eps: float = 1e-6, guidance_embeds: bool = True, + od_config: OmniDiffusionConfig = None, ): super().__init__() self.out_channels = out_channels or in_channels @@ -601,6 +628,13 @@ def __init__( guidance_embeds=guidance_embeds, ) + if od_config is not None: + self.parallel_config = od_config.parallel_config + else: + from vllm_omni.diffusion.data import DiffusionParallelConfig + + self.parallel_config = DiffusionParallelConfig() + self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=list(axes_dims_rope)) self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( in_channels=timestep_guidance_channels, @@ -672,6 +706,25 @@ def forward( num_txt_tokens = encoder_hidden_states.shape[1] + sp_size = self.parallel_config.sequence_parallel_size + get_forward_context().sequence_parallel_size = sp_size + if sp_size > 1: + sp_world_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + original_shape = hidden_states.shape + hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank] + get_forward_context().split_text_embed_in_sp = False + if not hasattr(self, "_sp_forward_logged"): + self._sp_forward_logged = True + logger.info( + f"[Flux2 Transformer] SP enabled: sp_size={sp_size}, world_size={sp_world_size}, " + f"rank={sp_rank}, original_shape={original_shape}, chunked_shape={hidden_states.shape}" + ) + else: + if not hasattr(self, "_sp_forward_logged"): + self._sp_forward_logged = True + logger.info(f"[Flux2 Transformer] SP disabled: sp_size={sp_size}") + timestep = timestep.to(hidden_states.dtype) * 1000 if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 @@ -690,11 +743,27 @@ def forward( if txt_ids.ndim == 3: txt_ids = txt_ids[0] - image_rotary_emb = self.pos_embed(img_ids) - text_rotary_emb = self.pos_embed(txt_ids) + if current_omni_platform.is_npu(): + img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids.cpu()) + img_freqs_cos, img_freqs_sin = img_freqs_cos.npu(), img_freqs_sin.npu() + txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids.cpu()) + txt_freqs_cos, txt_freqs_sin = txt_freqs_cos.npu(), txt_freqs_sin.npu() + else: + img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids) + txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids) + + if sp_size > 1: + sp_world_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + img_freqs_cos = torch.chunk(img_freqs_cos, sp_world_size, dim=0)[sp_rank] + img_freqs_sin = torch.chunk(img_freqs_sin, sp_world_size, dim=0)[sp_rank] + if get_forward_context().split_text_embed_in_sp: + txt_freqs_cos = torch.chunk(txt_freqs_cos, sp_world_size, dim=0)[sp_rank] + txt_freqs_sin = torch.chunk(txt_freqs_sin, sp_world_size, dim=0)[sp_rank] + concat_rotary_emb = ( - torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), - torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), + torch.cat([txt_freqs_cos, img_freqs_cos], dim=0), + torch.cat([txt_freqs_sin, img_freqs_sin], dim=0), ) for block in self.transformer_blocks: @@ -722,6 +791,9 @@ def forward( hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) + if self.parallel_config.sequence_parallel_size > 1: + output = get_sp_group().all_gather(output, dim=1) + if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index e1ef706c3f5..0de6e11a650 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -230,7 +230,7 @@ def __init__( ).to(self._execution_device) transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, Flux2Transformer2DModel) - self.transformer = Flux2Transformer2DModel(**transformer_kwargs) + self.transformer = Flux2Transformer2DModel(od_config=od_config, **transformer_kwargs) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) From f2daa529db26e31a1626a030d24acd86af55ab63 Mon Sep 17 00:00:00 2001 From: Lancer Date: Sun, 8 Feb 2026 01:05:06 +0800 Subject: [PATCH 02/11] upd Signed-off-by: Lancer --- .../models/flux2_klein/flux2_klein_transformer.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py index 8e1bf3c7af5..ed084c7586f 100644 --- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -46,6 +46,7 @@ get_sequence_parallel_world_size, get_sp_group, ) +from vllm_omni.diffusion.distributed.sp_sharding import sp_shard_with_padding from vllm_omni.diffusion.forward_context import get_forward_context from vllm_omni.diffusion.layers.rope import RotaryEmbedding from vllm_omni.platforms import current_omni_platform @@ -708,11 +709,12 @@ def forward( sp_size = self.parallel_config.sequence_parallel_size get_forward_context().sequence_parallel_size = sp_size + sp_pad_size = 0 if sp_size > 1: sp_world_size = get_sequence_parallel_world_size() sp_rank = get_sequence_parallel_rank() original_shape = hidden_states.shape - hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank] + hidden_states, sp_pad_size = sp_shard_with_padding(hidden_states, dim=1) get_forward_context().split_text_embed_in_sp = False if not hasattr(self, "_sp_forward_logged"): self._sp_forward_logged = True @@ -755,11 +757,11 @@ def forward( if sp_size > 1: sp_world_size = get_sequence_parallel_world_size() sp_rank = get_sequence_parallel_rank() - img_freqs_cos = torch.chunk(img_freqs_cos, sp_world_size, dim=0)[sp_rank] - img_freqs_sin = torch.chunk(img_freqs_sin, sp_world_size, dim=0)[sp_rank] + img_freqs_cos, _ = sp_shard_with_padding(img_freqs_cos, dim=0) + img_freqs_sin, _ = sp_shard_with_padding(img_freqs_sin, dim=0) if get_forward_context().split_text_embed_in_sp: - txt_freqs_cos = torch.chunk(txt_freqs_cos, sp_world_size, dim=0)[sp_rank] - txt_freqs_sin = torch.chunk(txt_freqs_sin, sp_world_size, dim=0)[sp_rank] + txt_freqs_cos, _ = sp_shard_with_padding(txt_freqs_cos, dim=0) + txt_freqs_sin, _ = sp_shard_with_padding(txt_freqs_sin, dim=0) concat_rotary_emb = ( torch.cat([txt_freqs_cos, img_freqs_cos], dim=0), @@ -793,6 +795,8 @@ def forward( if self.parallel_config.sequence_parallel_size > 1: output = get_sp_group().all_gather(output, dim=1) + if sp_pad_size > 0: + output = output[:, :-sp_pad_size, ...] if not return_dict: return (output,) From 225ec8a6cb8d60ec769bf12369846cf948aef19b Mon Sep 17 00:00:00 2001 From: Lancer Date: Sun, 8 Feb 2026 01:14:49 +0800 Subject: [PATCH 03/11] upd Signed-off-by: Lancer --- .../models/flux2_klein/flux2_klein_transformer.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py index ed084c7586f..b1536713b58 100644 --- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -49,7 +49,6 @@ from vllm_omni.diffusion.distributed.sp_sharding import sp_shard_with_padding from vllm_omni.diffusion.forward_context import get_forward_context from vllm_omni.diffusion.layers.rope import RotaryEmbedding -from vllm_omni.platforms import current_omni_platform logger = init_logger(__name__) @@ -745,14 +744,8 @@ def forward( if txt_ids.ndim == 3: txt_ids = txt_ids[0] - if current_omni_platform.is_npu(): - img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids.cpu()) - img_freqs_cos, img_freqs_sin = img_freqs_cos.npu(), img_freqs_sin.npu() - txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids.cpu()) - txt_freqs_cos, txt_freqs_sin = txt_freqs_cos.npu(), txt_freqs_sin.npu() - else: - img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids) - txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids) + img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids) + txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids) if sp_size > 1: sp_world_size = get_sequence_parallel_world_size() From c57c8666fc868e8d531e93f5267e26a734539ede Mon Sep 17 00:00:00 2001 From: Lancer Date: Sun, 8 Feb 2026 08:54:02 +0800 Subject: [PATCH 04/11] upd Signed-off-by: Lancer --- docs/user_guide/diffusion/parallelism_acceleration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index 6e2c18d64c8..38127ce63b7 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -29,7 +29,7 @@ The following table shows which models are currently supported by parallelism me | **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | ✅ | | **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ✅ (TP=2 only) | | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ❌ | -| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ❌ | ❌ | ✅ | +| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ✅ | ✅ | ❌ | ✅ | | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ❌ | ✅ | !!! note "TP Limitations for Diffusion Models" From b33ce101fa8b23ff6b9b66ae0be09228d9b7ba27 Mon Sep 17 00:00:00 2001 From: Lancer Date: Mon, 9 Feb 2026 22:54:09 +0800 Subject: [PATCH 05/11] upd Signed-off-by: Lancer --- .../flux2_klein/flux2_klein_transformer.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py index b1536713b58..eeb6eb7be9c 100644 --- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -42,8 +42,6 @@ from vllm_omni.diffusion.attention.layer import Attention from vllm_omni.diffusion.data import OmniDiffusionConfig from vllm_omni.diffusion.distributed.parallel_state import ( - get_sequence_parallel_rank, - get_sequence_parallel_world_size, get_sp_group, ) from vllm_omni.diffusion.distributed.sp_sharding import sp_shard_with_padding @@ -707,24 +705,10 @@ def forward( num_txt_tokens = encoder_hidden_states.shape[1] sp_size = self.parallel_config.sequence_parallel_size - get_forward_context().sequence_parallel_size = sp_size sp_pad_size = 0 if sp_size > 1: - sp_world_size = get_sequence_parallel_world_size() - sp_rank = get_sequence_parallel_rank() - original_shape = hidden_states.shape hidden_states, sp_pad_size = sp_shard_with_padding(hidden_states, dim=1) get_forward_context().split_text_embed_in_sp = False - if not hasattr(self, "_sp_forward_logged"): - self._sp_forward_logged = True - logger.info( - f"[Flux2 Transformer] SP enabled: sp_size={sp_size}, world_size={sp_world_size}, " - f"rank={sp_rank}, original_shape={original_shape}, chunked_shape={hidden_states.shape}" - ) - else: - if not hasattr(self, "_sp_forward_logged"): - self._sp_forward_logged = True - logger.info(f"[Flux2 Transformer] SP disabled: sp_size={sp_size}") timestep = timestep.to(hidden_states.dtype) * 1000 if guidance is not None: @@ -748,8 +732,6 @@ def forward( txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids) if sp_size > 1: - sp_world_size = get_sequence_parallel_world_size() - sp_rank = get_sequence_parallel_rank() img_freqs_cos, _ = sp_shard_with_padding(img_freqs_cos, dim=0) img_freqs_sin, _ = sp_shard_with_padding(img_freqs_sin, dim=0) if get_forward_context().split_text_embed_in_sp: From 780ed851cac2a9061ecb3452ed004b6d51f4b0d9 Mon Sep 17 00:00:00 2001 From: Lancer Date: Tue, 3 Mar 2026 12:52:18 +0800 Subject: [PATCH 06/11] upd Signed-off-by: Lancer --- .../diffusion/parallelism_acceleration.md | 16 ---------------- .../flux2_klein/flux2_klein_transformer.py | 6 ------ 2 files changed, 22 deletions(-) diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index 77b85dae09d..b0225e2ffb9 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -22,21 +22,6 @@ The following table shows which models are currently supported by parallelism me ### ImageGen -<<<<<<< spforflux2klein -| Model | Model Identifier | Ulysses-SP | Ring-SP | CFG-Parallel | Tensor-Parallel | -|--------------------------|--------------------------------------|:----------:|:-------:|:------------:|:---------------:| -| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ✅ | ✅ | ❌ | ✅ | -| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ✅ | ✅ | ❌ | ✅ | -| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ❌ | ❌ | ❌ | -| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | ✅ | -| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ✅ (TP=2 only) | -| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ❌ | -| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ✅ | ✅ | ❌ | ✅ | -| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ❌ | ✅ | -======= | Model | Model Identifier | Ulysses-SP | Ring-SP | CFG-Parallel | Tensor-Parallel | VAE-Patch-Parallel | |--------------------------|--------------------------------------|:----------:|:-------:|:------------:|:---------------:|:------------------:| | **LongCat-Image** | `meituan-longcat/LongCat-Image` | ✅ | ✅ | ❌ | ✅ | ❌ | @@ -50,7 +35,6 @@ The following table shows which models are currently supported by parallelism me | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ✅ | ✅ | | **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ❌ | ❌ | ✅ | ❌ | | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ✅ | ✅ | ❌ | ->>>>>>> main !!! note "TP Limitations for Diffusion Models" We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the `text_encoder` component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP. diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py index f5cd098efeb..1dbcfea3144 100644 --- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -48,12 +48,9 @@ from vllm_omni.diffusion.forward_context import get_forward_context from vllm_omni.diffusion.layers.rope import RotaryEmbedding -<<<<<<< spforflux2klein logger = init_logger(__name__) -======= if TYPE_CHECKING: from vllm.model_executor.layers.quantization.base_config import QuantizationConfig ->>>>>>> main class Flux2SwiGLU(nn.Module): @@ -630,11 +627,8 @@ def __init__( rope_theta: int = 2000, eps: float = 1e-6, guidance_embeds: bool = True, -<<<<<<< spforflux2klein od_config: OmniDiffusionConfig = None, -======= quant_config: "QuantizationConfig | None" = None, ->>>>>>> main ): super().__init__() self.out_channels = out_channels or in_channels From 8a7e687cf939b3b86a05fa58c77388da0d0a6f4a Mon Sep 17 00:00:00 2001 From: Lancer Date: Mon, 9 Mar 2026 17:26:50 +0800 Subject: [PATCH 07/11] upd Signed-off-by: Lancer --- .../flux2_klein/flux2_klein_transformer.py | 260 +++++++++++++----- 1 file changed, 193 insertions(+), 67 deletions(-) diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py index 1dbcfea3144..ad4d7bc232b 100644 --- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -41,10 +41,10 @@ from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata from vllm_omni.diffusion.attention.layer import Attention from vllm_omni.diffusion.data import OmniDiffusionConfig -from vllm_omni.diffusion.distributed.parallel_state import ( - get_sp_group, +from vllm_omni.diffusion.distributed.sp_plan import ( + SequenceParallelInput, + SequenceParallelOutput, ) -from vllm_omni.diffusion.distributed.sp_sharding import sp_shard_with_padding from vllm_omni.diffusion.forward_context import get_forward_context from vllm_omni.diffusion.layers.rope import RotaryEmbedding @@ -216,33 +216,84 @@ def forward( encoder_query = self.norm_added_q(encoder_query) encoder_key = self.norm_added_k(encoder_key) - query = torch.cat([encoder_query, query], dim=1) - key = torch.cat([encoder_key, key], dim=1) - value = torch.cat([encoder_value, value], dim=1) - - if image_rotary_emb is not None: - cos, sin = image_rotary_emb - cos = cos.to(query.dtype) - sin = sin.to(query.dtype) - query = self.rope(query, cos, sin) - key = self.rope(key, cos, sin) - - attn_metadata = None - if attention_mask is not None: - if attention_mask.dim() == 3: - attention_mask = attention_mask.unsqueeze(1) - attn_metadata = AttentionMetadata(attn_mask=attention_mask) - - hidden_states = self.attn(query, key, value, attn_metadata) - hidden_states = hidden_states.flatten(2, 3).to(query.dtype) - - if encoder_hidden_states is not None: - context_len = encoder_hidden_states.shape[1] - encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( - [context_len, hidden_states.shape[1] - context_len], - dim=1, - ) - encoder_hidden_states = self.to_add_out(encoder_hidden_states) + forward_ctx = get_forward_context() + use_sp_joint_attention = forward_ctx.sp_active and not forward_ctx.split_text_embed_in_sp + + if use_sp_joint_attention and image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + txt_len = encoder_query.shape[1] + txt_cos, img_cos = cos[:txt_len], cos[txt_len:] + txt_sin, img_sin = sin[:txt_len], sin[txt_len:] + query = self.rope(query, img_cos, img_sin) + key = self.rope(key, img_cos, img_sin) + encoder_query = self.rope(encoder_query, txt_cos, txt_sin) + encoder_key = self.rope(encoder_key, txt_cos, txt_sin) + + attn_metadata = AttentionMetadata( + joint_query=encoder_query, + joint_key=encoder_key, + joint_value=encoder_value, + joint_strategy="front", + ) + if attention_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_metadata.attn_mask = attention_mask + + hidden_states = self.attn(query, key, value, attn_metadata) + hidden_states = hidden_states.flatten(2, 3).to(query.dtype) + + txt_len = encoder_hidden_states.shape[1] + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [txt_len, hidden_states.shape[1] - txt_len], + dim=1, + ) + encoder_hidden_states = self.to_add_out(encoder_hidden_states) + else: + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + query = self.rope(query, cos, sin) + key = self.rope(key, cos, sin) + + attn_metadata = None + if attention_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_metadata = AttentionMetadata(attn_mask=attention_mask) + + hidden_states = self.attn(query, key, value, attn_metadata) + hidden_states = hidden_states.flatten(2, 3).to(query.dtype) + + context_len = encoder_hidden_states.shape[1] + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [context_len, hidden_states.shape[1] - context_len], + dim=1, + ) + encoder_hidden_states = self.to_add_out(encoder_hidden_states) + else: + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + query = self.rope(query, cos, sin) + key = self.rope(key, cos, sin) + + attn_metadata = None + if attention_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_metadata = AttentionMetadata(attn_mask=attention_mask) + + hidden_states = self.attn(query, key, value, attn_metadata) + hidden_states = hidden_states.flatten(2, 3).to(query.dtype) hidden_states = self.to_out[0](hidden_states) hidden_states = self.to_out[1](hidden_states) @@ -333,20 +384,59 @@ def forward( query = self.norm_q(query) key = self.norm_k(key) - if image_rotary_emb is not None: + forward_ctx = get_forward_context() + text_seq_len = kwargs.get("text_seq_len", None) + use_sp_single_stream = ( + forward_ctx.sp_active and not forward_ctx.split_text_embed_in_sp and text_seq_len is not None + ) + + if use_sp_single_stream and image_rotary_emb is not None: cos, sin = image_rotary_emb cos = cos.to(query.dtype) sin = sin.to(query.dtype) - query = self.rope(query, cos, sin) - key = self.rope(key, cos, sin) + txt_cos, img_cos = cos[:text_seq_len], cos[text_seq_len:] + txt_sin, img_sin = sin[:text_seq_len], sin[text_seq_len:] + + img_query = query[:, text_seq_len:] + img_key = key[:, text_seq_len:] + img_value = value[:, text_seq_len:] + text_query = query[:, :text_seq_len] + text_key = key[:, :text_seq_len] + text_value = value[:, :text_seq_len] + + img_query = self.rope(img_query, img_cos, img_sin) + img_key = self.rope(img_key, img_cos, img_sin) + text_query = self.rope(text_query, txt_cos, txt_sin) + text_key = self.rope(text_key, txt_cos, txt_sin) + + attn_metadata = AttentionMetadata( + joint_query=text_query, + joint_key=text_key, + joint_value=text_value, + joint_strategy="front", + ) + if attention_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_metadata.attn_mask = attention_mask + + attn_output = self.attn(img_query, img_key, img_value, attn_metadata) + else: + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + query = self.rope(query, cos, sin) + key = self.rope(key, cos, sin) - attn_metadata = None - if attention_mask is not None: - if attention_mask.dim() == 3: - attention_mask = attention_mask.unsqueeze(1) - attn_metadata = AttentionMetadata(attn_mask=attention_mask) + attn_metadata = None + if attention_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_metadata = AttentionMetadata(attn_mask=attention_mask) + + attn_output = self.attn(query, key, value, attn_metadata) - attn_output = self.attn(query, key, value, attn_metadata) attn_output = attn_output.flatten(2, 3).to(query.dtype) mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states) @@ -356,12 +446,6 @@ def forward( class Flux2SingleTransformerBlock(nn.Module): - """ - Single-stream Transformer block for Flux 2 with SP (Sequence Parallelism) support. - - SP handling is delegated to Flux2Attention via the forward context. - """ - def __init__( self, dim: int, @@ -546,6 +630,32 @@ def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return freqs_cos, freqs_sin +class Flux2RopePrepare(nn.Module): + """Prepares hidden_states and RoPE embeddings for sequence parallel. + + This module encapsulates the input projection and RoPE computation for Flux.2-klein. + The key insight is that hidden_states and img_freqs must be sharded together + to maintain dimension alignment for RoPE computation in attention layers. + txt_freqs is kept replicated for dual-stream joint attention. + """ + + def __init__(self, x_embedder: nn.Linear, pos_embed: Flux2PosEmbed): + super().__init__() + self.x_embedder = x_embedder + self.pos_embed = pos_embed + + def forward( + self, + hidden_states: torch.Tensor, + img_ids: torch.Tensor, + txt_ids: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states = self.x_embedder(hidden_states) + img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids) + txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids) + return hidden_states, txt_freqs_cos, txt_freqs_sin, img_freqs_cos, img_freqs_sin + + class Flux2TimestepGuidanceEmbeddings(nn.Module): def __init__( self, @@ -611,6 +721,16 @@ class Flux2Transformer2DModel(nn.Module): "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], } + _sp_plan = { + "rope_prepare": { + 0: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True, auto_pad=True), + 3: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True), + 4: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True), + }, + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } + """SP plan: shard hidden_states/img_freqs at rope_prepare, gather output at proj_out.""" + def __init__( self, patch_size: int = 1, @@ -672,6 +792,8 @@ def __init__( self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False) self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False) + self.rope_prepare = Flux2RopePrepare(self.x_embedder, self.pos_embed) + self.transformer_blocks = nn.ModuleList( [ Flux2TransformerBlock( @@ -730,11 +852,7 @@ def forward( num_txt_tokens = encoder_hidden_states.shape[1] - sp_size = self.parallel_config.sequence_parallel_size - sp_pad_size = 0 - if sp_size > 1: - hidden_states, sp_pad_size = sp_shard_with_padding(hidden_states, dim=1) - get_forward_context().split_text_embed_in_sp = False + get_forward_context().split_text_embed_in_sp = False timestep = timestep.to(hidden_states.dtype) * 1000 if guidance is not None: @@ -746,29 +864,41 @@ def forward( double_stream_mod_txt = self.double_stream_modulation_txt(temb) single_stream_mod = self.single_stream_modulation(temb)[0] - hidden_states = self.x_embedder(hidden_states) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) - if img_ids.ndim == 3: img_ids = img_ids[0] if txt_ids.ndim == 3: txt_ids = txt_ids[0] - img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids) - txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids) - - if sp_size > 1: - img_freqs_cos, _ = sp_shard_with_padding(img_freqs_cos, dim=0) - img_freqs_sin, _ = sp_shard_with_padding(img_freqs_sin, dim=0) - if get_forward_context().split_text_embed_in_sp: - txt_freqs_cos, _ = sp_shard_with_padding(txt_freqs_cos, dim=0) - txt_freqs_sin, _ = sp_shard_with_padding(txt_freqs_sin, dim=0) + hidden_states, txt_freqs_cos, txt_freqs_sin, img_freqs_cos, img_freqs_sin = self.rope_prepare( + hidden_states, img_ids, txt_ids + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) concat_rotary_emb = ( torch.cat([txt_freqs_cos, img_freqs_cos], dim=0), torch.cat([txt_freqs_sin, img_freqs_sin], dim=0), ) + hidden_states_mask = None + ctx = get_forward_context() + if ctx.sp_active: + if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0: + batch_size = hidden_states.shape[0] + img_padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size + full_seq_len = num_txt_tokens + img_padded_seq_len + hidden_states_mask = torch.ones( + batch_size, + full_seq_len, + dtype=torch.bool, + device=hidden_states.device, + ) + hidden_states_mask[:, num_txt_tokens + ctx.sp_original_seq_len :] = False + if hidden_states_mask.all(): + hidden_states_mask = None + + if hidden_states_mask is not None: + joint_attention_kwargs["attention_mask"] = hidden_states_mask + for block in self.transformer_blocks: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, @@ -788,17 +918,13 @@ def forward( temb_mod_params=single_stream_mod, image_rotary_emb=concat_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, + text_seq_len=num_txt_tokens, ) hidden_states = hidden_states[:, num_txt_tokens:, ...] hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) - if self.parallel_config.sequence_parallel_size > 1: - output = get_sp_group().all_gather(output, dim=1) - if sp_pad_size > 0: - output = output[:, :-sp_pad_size, ...] - if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) From 63c4658e2116a0e0f553633ce3fd89a22e636bb0 Mon Sep 17 00:00:00 2001 From: Lancer Date: Tue, 10 Mar 2026 07:48:35 +0800 Subject: [PATCH 08/11] upd Signed-off-by: Lancer --- vllm_omni/diffusion/layers/rope.py | 26 ++++++++++++++ .../diffusion/models/flux/flux_transformer.py | 9 ++--- .../flux2_klein/flux2_klein_transformer.py | 36 ++++++++----------- .../models/z_image/z_image_transformer.py | 7 ++-- 4 files changed, 45 insertions(+), 33 deletions(-) diff --git a/vllm_omni/diffusion/layers/rope.py b/vllm_omni/diffusion/layers/rope.py index 461c25652e4..65d37d0b017 100644 --- a/vllm_omni/diffusion/layers/rope.py +++ b/vllm_omni/diffusion/layers/rope.py @@ -157,3 +157,29 @@ def forward_native( sin, interleaved=self.interleaved, ) + + +def apply_rope_to_qk( + rope: RotaryEmbedding, + query: torch.Tensor, + key: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Apply rotary positional embeddings to query and key tensors. + + Args: + rope: RotaryEmbedding instance for applying position embeddings + query: Query tensor [B, S, H, D] + key: Key tensor [B, S, H, D] + image_rotary_emb: Tuple of (cos, sin) tensors or None + + Returns: + Tuple of (query, key) with RoPE applied if rotary embeddings provided + """ + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + query = rope(query, cos, sin) + key = rope(key, cos, sin) + return query, key diff --git a/vllm_omni/diffusion/models/flux/flux_transformer.py b/vllm_omni/diffusion/models/flux/flux_transformer.py index faf6d08d3a5..7613d8fa7d8 100644 --- a/vllm_omni/diffusion/models/flux/flux_transformer.py +++ b/vllm_omni/diffusion/models/flux/flux_transformer.py @@ -22,7 +22,7 @@ from vllm_omni.diffusion.attention.layer import Attention from vllm_omni.diffusion.data import OmniDiffusionConfig -from vllm_omni.diffusion.layers.rope import RotaryEmbedding +from vllm_omni.diffusion.layers.rope import RotaryEmbedding, apply_rope_to_qk logger = init_logger(__name__) @@ -200,12 +200,7 @@ def forward( key = torch.cat([encoder_key, key], dim=1) value = torch.cat([encoder_value, value], dim=1) - if image_rotary_emb is not None: - cos, sin = image_rotary_emb # [S, D/2] - cos = cos.to(query.dtype) - sin = sin.to(query.dtype) - query = self.rope(query, cos, sin) - key = self.rope(key, cos, sin) + query, key = apply_rope_to_qk(self.rope, query, key, image_rotary_emb) # [S, D/2] hidden_states = self.attn( query, diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py index ad4d7bc232b..4cbddc32afc 100644 --- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -46,7 +46,7 @@ SequenceParallelOutput, ) from vllm_omni.diffusion.forward_context import get_forward_context -from vllm_omni.diffusion.layers.rope import RotaryEmbedding +from vllm_omni.diffusion.layers.rope import RotaryEmbedding, apply_rope_to_qk logger = init_logger(__name__) if TYPE_CHECKING: @@ -422,12 +422,7 @@ def forward( attn_output = self.attn(img_query, img_key, img_value, attn_metadata) else: - if image_rotary_emb is not None: - cos, sin = image_rotary_emb - cos = cos.to(query.dtype) - sin = sin.to(query.dtype) - query = self.rope(query, cos, sin) - key = self.rope(key, cos, sin) + query, key = apply_rope_to_qk(self.rope, query, key, image_rotary_emb) # [S, D/2] attn_metadata = None if attention_mask is not None: @@ -881,20 +876,19 @@ def forward( hidden_states_mask = None ctx = get_forward_context() - if ctx.sp_active: - if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0: - batch_size = hidden_states.shape[0] - img_padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size - full_seq_len = num_txt_tokens + img_padded_seq_len - hidden_states_mask = torch.ones( - batch_size, - full_seq_len, - dtype=torch.bool, - device=hidden_states.device, - ) - hidden_states_mask[:, num_txt_tokens + ctx.sp_original_seq_len :] = False - if hidden_states_mask.all(): - hidden_states_mask = None + if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0: + batch_size = hidden_states.shape[0] + img_padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size + full_seq_len = num_txt_tokens + img_padded_seq_len + hidden_states_mask = torch.ones( + batch_size, + full_seq_len, + dtype=torch.bool, + device=hidden_states.device, + ) + hidden_states_mask[:, num_txt_tokens + ctx.sp_original_seq_len :] = False + if hidden_states_mask.all(): + hidden_states_mask = None if hidden_states_mask is not None: joint_attention_kwargs["attention_mask"] = hidden_states_mask diff --git a/vllm_omni/diffusion/models/z_image/z_image_transformer.py b/vllm_omni/diffusion/models/z_image/z_image_transformer.py index 83d0ecb9b95..48ffa43ade3 100644 --- a/vllm_omni/diffusion/models/z_image/z_image_transformer.py +++ b/vllm_omni/diffusion/models/z_image/z_image_transformer.py @@ -49,7 +49,7 @@ get_forward_context, is_forward_context_available, ) -from vllm_omni.diffusion.layers.rope import RotaryEmbedding +from vllm_omni.diffusion.layers.rope import RotaryEmbedding, apply_rope_to_qk ADALN_EMBED_DIM = 256 SEQ_MULTI_OF = 32 @@ -329,10 +329,7 @@ def forward( query = self.norm_q(query) key = self.norm_k(key) - cos = cos.to(query.dtype) - sin = sin.to(query.dtype) - query = self.rope(query, cos, sin) - key = self.rope(key, cos, sin) + query, key = apply_rope_to_qk(self.rope, query, key, (cos, sin)) # Cast to correct dtype dtype = query.dtype query, key = query.to(dtype), key.to(dtype) From 876d49279a20510893dd714d5867229b89cb8468 Mon Sep 17 00:00:00 2001 From: Lancer Date: Tue, 10 Mar 2026 19:01:13 +0800 Subject: [PATCH 09/11] upd Signed-off-by: Lancer --- .../flux2_klein/flux2_klein_transformer.py | 75 +++++++++++++------ 1 file changed, 53 insertions(+), 22 deletions(-) diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py index f17a231d78b..dd9ba57d776 100644 --- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -40,7 +40,7 @@ from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata from vllm_omni.diffusion.attention.layer import Attention -from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig from vllm_omni.diffusion.distributed.sp_plan import ( SequenceParallelInput, SequenceParallelOutput, @@ -106,6 +106,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Flux2Attention(nn.Module): def __init__( self, + parallel_config: DiffusionParallelConfig, query_dim: int, heads: int = 8, dim_head: int = 64, @@ -120,6 +121,7 @@ def __init__( quant_config: "QuantizationConfig | None" = None, ): super().__init__() + self.parallel_config = parallel_config self.head_dim = dim_head self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.query_dim = query_dim @@ -216,8 +218,9 @@ def forward( encoder_query = self.norm_added_q(encoder_query) encoder_key = self.norm_added_k(encoder_key) + sp_size = self.parallel_config.sequence_parallel_size forward_ctx = get_forward_context() - use_sp_joint_attention = forward_ctx.sp_active and not forward_ctx.split_text_embed_in_sp + use_sp_joint_attention = sp_size is not None and sp_size > 1 and not forward_ctx.split_text_embed_in_sp if use_sp_joint_attention and image_rotary_emb is not None: cos, sin = image_rotary_emb @@ -310,6 +313,7 @@ class Flux2ParallelSelfAttention(nn.Module): def __init__( self, + parallel_config: DiffusionParallelConfig, query_dim: int, heads: int = 8, dim_head: int = 64, @@ -324,6 +328,7 @@ def __init__( quant_config: "QuantizationConfig | None" = None, ): super().__init__() + self.parallel_config = parallel_config self.head_dim = dim_head self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.query_dim = query_dim @@ -384,10 +389,11 @@ def forward( query = self.norm_q(query) key = self.norm_k(key) + sp_size = self.parallel_config.sequence_parallel_size forward_ctx = get_forward_context() text_seq_len = kwargs.get("text_seq_len", None) use_sp_single_stream = ( - forward_ctx.sp_active and not forward_ctx.split_text_embed_in_sp and text_seq_len is not None + sp_size is not None and sp_size > 1 and not forward_ctx.split_text_embed_in_sp and text_seq_len is not None ) if use_sp_single_stream and image_rotary_emb is not None: @@ -443,6 +449,7 @@ def forward( class Flux2SingleTransformerBlock(nn.Module): def __init__( self, + parallel_config: DiffusionParallelConfig, dim: int, num_attention_heads: int, attention_head_dim: int, @@ -454,6 +461,7 @@ def __init__( super().__init__() self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) self.attn = Flux2ParallelSelfAttention( + parallel_config=parallel_config, query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, @@ -512,6 +520,7 @@ def forward( class Flux2TransformerBlock(nn.Module): def __init__( self, + parallel_config: DiffusionParallelConfig, dim: int, num_attention_heads: int, attention_head_dim: int, @@ -525,6 +534,7 @@ def __init__( self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) self.attn = Flux2Attention( + parallel_config=parallel_config, query_dim=dim, added_kv_proj_dim=dim, dim_head=attention_head_dim, @@ -626,29 +636,44 @@ def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: class Flux2RopePrepare(nn.Module): - """Prepares hidden_states and RoPE embeddings for sequence parallel. + """Prepares RoPE embeddings for sequence parallel. - This module encapsulates the input projection and RoPE computation for Flux.2-klein. - The key insight is that hidden_states and img_freqs must be sharded together - to maintain dimension alignment for RoPE computation in attention layers. - txt_freqs is kept replicated for dual-stream joint attention. + This module encapsulates the RoPE computation for Flux.2-klein. + For dual-stream attention, text components (outputs 0, 1) are replicated + across SP ranks, while image components (outputs 2, 3) are sharded. + + NOTE: The hidden_states projection is handled separately in forward() + so that _sp_plan can shard it at the root level. """ - def __init__(self, x_embedder: nn.Linear, pos_embed: Flux2PosEmbed): + def __init__(self, pos_embed: Flux2PosEmbed): super().__init__() - self.x_embedder = x_embedder self.pos_embed = pos_embed def forward( self, - hidden_states: torch.Tensor, img_ids: torch.Tensor, txt_ids: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - hidden_states = self.x_embedder(hidden_states) + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute RoPE embeddings for text and image sequences. + + Args: + img_ids: Image position IDs (img_seq_len, n_axes) + txt_ids: Text position IDs (txt_seq_len, n_axes) + + Returns: + Tuple of cosine / sine components for text & image + in the order: (txt_cos, txt_sin, img_cos, img_sin) + + NOTE: careful about output orders if this is refactored in the + future; we need to match the _sp_plan indices, since text + components (0 & 1) need to be replicated across SP ranks, + while image components (2 & 3) must be sharded. + """ img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids) txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids) - return hidden_states, txt_freqs_cos, txt_freqs_sin, img_freqs_cos, img_freqs_sin + return txt_freqs_cos, txt_freqs_sin, img_freqs_cos, img_freqs_sin class Flux2TimestepGuidanceEmbeddings(nn.Module): @@ -713,14 +738,16 @@ class Flux2Transformer2DModel(nn.Module): _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] _sp_plan = { + "": { + "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, auto_pad=True), + }, "rope_prepare": { - 0: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True, auto_pad=True), + 2: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True), 3: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True), - 4: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True), }, "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), } - """SP plan: shard hidden_states/img_freqs at rope_prepare, gather output at proj_out.""" + """SP plan: shard hidden_states at root level, shard img_freqs at rope_prepare, gather output at proj_out.""" def __init__( self, @@ -783,11 +810,12 @@ def __init__( self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False) self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False) - self.rope_prepare = Flux2RopePrepare(self.x_embedder, self.pos_embed) + self.rope_prepare = Flux2RopePrepare(self.pos_embed) self.transformer_blocks = nn.ModuleList( [ Flux2TransformerBlock( + parallel_config=self.parallel_config, dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, @@ -803,6 +831,7 @@ def __init__( self.single_transformer_blocks = nn.ModuleList( [ Flux2SingleTransformerBlock( + parallel_config=self.parallel_config, dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, @@ -843,7 +872,9 @@ def forward( num_txt_tokens = encoder_hidden_states.shape[1] - get_forward_context().split_text_embed_in_sp = False + sp_size = self.parallel_config.sequence_parallel_size + if sp_size is not None and sp_size > 1: + get_forward_context().split_text_embed_in_sp = False timestep = timestep.to(hidden_states.dtype) * 1000 if guidance is not None: @@ -860,11 +891,11 @@ def forward( if txt_ids.ndim == 3: txt_ids = txt_ids[0] - hidden_states, txt_freqs_cos, txt_freqs_sin, img_freqs_cos, img_freqs_sin = self.rope_prepare( - hidden_states, img_ids, txt_ids - ) + hidden_states = self.x_embedder(hidden_states) encoder_hidden_states = self.context_embedder(encoder_hidden_states) + txt_freqs_cos, txt_freqs_sin, img_freqs_cos, img_freqs_sin = self.rope_prepare(img_ids, txt_ids) + concat_rotary_emb = ( torch.cat([txt_freqs_cos, img_freqs_cos], dim=0), torch.cat([txt_freqs_sin, img_freqs_sin], dim=0), From 6a1e5f220bbc53a4d2918f2d4c3a23b77cc3356f Mon Sep 17 00:00:00 2001 From: Lancer Date: Wed, 11 Mar 2026 17:20:47 +0800 Subject: [PATCH 10/11] upd Signed-off-by: Lancer --- .../flux2_klein/flux2_klein_transformer.py | 34 ++++++++++++------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py index dd9ba57d776..c10f06751dd 100644 --- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -240,10 +240,12 @@ def forward( joint_value=encoder_value, joint_strategy="front", ) - if attention_mask is not None: - if attention_mask.dim() == 3: - attention_mask = attention_mask.unsqueeze(1) - attn_metadata.attn_mask = attention_mask + hidden_states_mask: torch.Tensor | None = kwargs.get("hidden_states_mask", None) + encoder_hidden_states_mask: torch.Tensor | None = kwargs.get("encoder_hidden_states_mask", None) + if hidden_states_mask is not None: + attn_metadata.attn_mask = hidden_states_mask + if encoder_hidden_states_mask is not None: + attn_metadata.joint_attn_mask = encoder_hidden_states_mask hidden_states = self.attn(query, key, value, attn_metadata) hidden_states = hidden_states.flatten(2, 3).to(query.dtype) @@ -421,14 +423,16 @@ def forward( joint_value=text_value, joint_strategy="front", ) - if attention_mask is not None: - if attention_mask.dim() == 3: - attention_mask = attention_mask.unsqueeze(1) - attn_metadata.attn_mask = attention_mask + hidden_states_mask: torch.Tensor | None = kwargs.get("hidden_states_mask", None) + encoder_hidden_states_mask: torch.Tensor | None = kwargs.get("encoder_hidden_states_mask", None) + if hidden_states_mask is not None: + attn_metadata.attn_mask = hidden_states_mask + if encoder_hidden_states_mask is not None: + attn_metadata.joint_attn_mask = encoder_hidden_states_mask attn_output = self.attn(img_query, img_key, img_value, attn_metadata) else: - query, key = apply_rope_to_qk(self.rope, query, key, image_rotary_emb) # [S, D/2] + query, key = apply_rope_to_qk(self.rope, query, key, image_rotary_emb) attn_metadata = None if attention_mask is not None: @@ -901,24 +905,28 @@ def forward( torch.cat([txt_freqs_sin, img_freqs_sin], dim=0), ) + # Create separate masks for image and text portions for Ulysses SP joint attention hidden_states_mask = None + encoder_hidden_states_mask = None ctx = get_forward_context() if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0: batch_size = hidden_states.shape[0] img_padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size - full_seq_len = num_txt_tokens + img_padded_seq_len + hidden_states_mask = torch.ones( batch_size, - full_seq_len, + img_padded_seq_len, dtype=torch.bool, device=hidden_states.device, ) - hidden_states_mask[:, num_txt_tokens + ctx.sp_original_seq_len :] = False + hidden_states_mask[:, ctx.sp_original_seq_len :] = False if hidden_states_mask.all(): hidden_states_mask = None if hidden_states_mask is not None: - joint_attention_kwargs["attention_mask"] = hidden_states_mask + joint_attention_kwargs["hidden_states_mask"] = hidden_states_mask + if encoder_hidden_states_mask is not None: + joint_attention_kwargs["encoder_hidden_states_mask"] = encoder_hidden_states_mask for block in self.transformer_blocks: encoder_hidden_states, hidden_states = block( From 187255a6b7987c8ec955514df253faa6b8087582 Mon Sep 17 00:00:00 2001 From: Lancer Date: Sun, 15 Mar 2026 10:20:27 +0800 Subject: [PATCH 11/11] upd Signed-off-by: Lancer --- docs/user_guide/diffusion/parallelism_acceleration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index bd2a58799f9..e9d65142945 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -35,7 +35,7 @@ The following table shows which models are currently supported by parallelism me | **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | ✅ | ❌ | N/A | | **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ✅ (TP=2 only) | ✅ | N/A | | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ✅ | ✅ | N/A | -| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ❌ | ❌ | ✅ | ❌ | N/A | +| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ✅ | ✅ | ❌ | ✅ | ❌ | N/A | | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ✅ | ✅ | ❌ | N/A | | **HunyuanImage3.0** | `tencent/HunyuanImage-3.0`, `tencent/HunyuanImage-3.0-Instruct` | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ |