From 4245ced5115fb5cc56d3095871db40308d75ae6e Mon Sep 17 00:00:00 2001 From: princepride Date: Sun, 15 Mar 2026 08:45:04 +0000 Subject: [PATCH 01/12] bagel support sp Signed-off-by: princepride --- .../models/bagel/bagel_transformer.py | 477 +++++++++++++++++- .../diffusion/models/bagel/pipeline_bagel.py | 4 +- vllm_omni/diffusion/registry.py | 2 +- 3 files changed, 454 insertions(+), 29 deletions(-) diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index 09144e05a91..d1f19fba559 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -13,6 +13,7 @@ import numpy as np import torch +import torch.distributed as dist from torch import nn from torch.nn.attention.flex_attention import flex_attention from transformers.models.qwen2.configuration_qwen2 import Qwen2Config @@ -31,12 +32,18 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.transformers_utils.configs.bagel import BagelConfig +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata as DiffusionAttentionMetadata from vllm_omni.diffusion.attention.backends.utils.fa import flash_attn_varlen_func +from vllm_omni.diffusion.attention.layer import Attention as DiffusionAttention +from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.diffusion.distributed.parallel_state import ( get_cfg_group, get_classifier_free_guidance_rank, get_classifier_free_guidance_world_size, + get_sequence_parallel_rank, + get_sp_group, ) +from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available from vllm_omni.diffusion.layers.rope import RotaryEmbedding @@ -266,10 +273,11 @@ class PackedAttentionMoT(nn.Module): - qkv_proj_moe_gen : stacks q_proj_moe_gen + k_proj_moe_gen + v_proj_moe_gen (gen vae) """ - def __init__(self, config, layer_idx: int | None = None): + def __init__(self, config, layer_idx: int | None = None, parallel_config: DiffusionParallelConfig | None = None): super().__init__() self.layer_idx = layer_idx self.hidden_size = config.hidden_size + self.parallel_config = parallel_config tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads @@ -321,6 +329,138 @@ def __init__(self, config, layer_idx: int | None = None): self.rotary_op = RotaryEmbedding(is_neox_style=True) + # SP (Ulysses / Ring) attention for generation mode denoising + sp_size = parallel_config.sequence_parallel_size if parallel_config is not None else 1 + if sp_size is not None and sp_size > 1: + self.sp_attn = DiffusionAttention( + num_heads=self.total_num_heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + num_kv_heads=self.total_num_kv_heads, + ) + else: + self.sp_attn = None + + def _is_sp_active(self) -> bool: + """Check if SP is active for this attention layer.""" + if self.sp_attn is None: + return False + if not is_forward_context_available(): + return False + return get_forward_context().sp_active + + def _forward_sp_gen( + self, + packed_query_sequence: torch.Tensor, + packed_query_position_embeddings: torch.Tensor, + past_key_values: NaiveCache | None, + packed_vae_token_indexes: torch.Tensor, + packed_text_indexes: torch.Tensor, + ) -> tuple[torch.Tensor, NaiveCache | None]: + """SP-aware attention for gen mode denoising. + + Converts packed format to batched (1, S, H, D) and uses the diffusion + Attention layer (Ulysses / Ring) with joint mechanism: + - Main Q/K/V: VAE tokens (split across SP ranks) + - Joint Q: text marker Q (replicated) + - Joint K/V: KV cache K/V + text marker K/V (replicated) + """ + packed_query_sequence = packed_query_sequence.to(torch.bfloat16) + + packed_text_query_sequence = packed_query_sequence[packed_text_indexes] + packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes] + + # Project text tokens through base qkv + text_qkv, _ = self.qkv_proj(packed_text_query_sequence) + text_q, text_k, text_v = text_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Project vae tokens through moe_gen qkv + vae_qkv, _ = self.qkv_proj_moe_gen(packed_vae_query_sequence) + vae_q, vae_k, vae_v = vae_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Reshape to (tokens, heads, head_dim) + text_q = text_q.view(-1, self.num_heads, self.head_dim) + text_k = text_k.view(-1, self.num_kv_heads, self.head_dim) + text_v = text_v.view(-1, self.num_kv_heads, self.head_dim) + vae_q = vae_q.view(-1, self.num_heads, self.head_dim) + vae_k = vae_k.view(-1, self.num_kv_heads, self.head_dim) + vae_v = vae_v.view(-1, self.num_kv_heads, self.head_dim) + + # Apply QK norms + text_q = self.q_norm(text_q.to(torch.float32)) + text_k = self.k_norm(text_k.to(torch.float32)) + vae_q = self.q_norm_moe_gen(vae_q.to(torch.float32)) + vae_k = self.k_norm_moe_gen(vae_k.to(torch.float32)) + + # Apply RoPE - need to build per-token cos/sin for text and vae separately + # packed_query_position_embeddings are ordered as the packed sequence + cos_full, sin_full = [x[..., : self.head_dim // 2] for x in packed_query_position_embeddings] + + # Extract cos/sin for text and vae positions + text_cos = cos_full[packed_text_indexes] + text_sin = sin_full[packed_text_indexes] + vae_cos = cos_full[packed_vae_token_indexes] + vae_sin = sin_full[packed_vae_token_indexes] + + text_q = self.rotary_op(text_q.to(text_cos.dtype).unsqueeze(0), text_cos, text_sin).squeeze(0) + text_k = self.rotary_op(text_k.to(text_cos.dtype).unsqueeze(0), text_cos, text_sin).squeeze(0) + vae_q = self.rotary_op(vae_q.to(vae_cos.dtype).unsqueeze(0), vae_cos, vae_sin).squeeze(0) + vae_k = self.rotary_op(vae_k.to(vae_cos.dtype).unsqueeze(0), vae_cos, vae_sin).squeeze(0) + + text_q = text_q.to(torch.bfloat16) + text_k = text_k.to(torch.bfloat16) + text_v = text_v.to(torch.bfloat16) + vae_q = vae_q.to(torch.bfloat16) + vae_k = vae_k.to(torch.bfloat16) + vae_v = vae_v.to(torch.bfloat16) + + # Build joint K/V: [kv_cache, text_markers] (replicated across SP ranks) + if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None: + cache_k = past_key_values.key_cache[self.layer_idx] + cache_v = past_key_values.value_cache[self.layer_idx] + joint_k = torch.cat([cache_k, text_k], dim=0).unsqueeze(0) + joint_v = torch.cat([cache_v, text_v], dim=0).unsqueeze(0) + else: + joint_k = text_k.unsqueeze(0) + joint_v = text_v.unsqueeze(0) + + # Reshape to batched (1, S, H, D) for diffusion Attention + vae_q_4d = vae_q.unsqueeze(0) + vae_k_4d = vae_k.unsqueeze(0) + vae_v_4d = vae_v.unsqueeze(0) + text_q_4d = text_q.unsqueeze(0) + + # Call SP-aware attention: VAE as main, text+cache as joint + attn_out = self.sp_attn( + vae_q_4d, + vae_k_4d, + vae_v_4d, + DiffusionAttentionMetadata( + joint_query=text_q_4d, + joint_key=joint_k, + joint_value=joint_v, + joint_strategy="front", + ), + ) + # attn_out: (1, text_len + local_vae_len, H, D) + text_len = text_q.shape[0] + attn_out = attn_out.squeeze(0) # (text_len + local_vae_len, H, D) + text_attn = attn_out[:text_len].reshape(text_len, self.q_size) + vae_attn = attn_out[text_len:].reshape(-1, self.q_size) + + # Apply output projections + text_out, _ = self.o_proj(text_attn) + vae_out, _ = self.o_proj_moe_gen(vae_attn) + + # Merge back into packed format + total_len = packed_query_sequence.shape[0] + full_output = text_out.new_zeros((total_len, self.hidden_size)) + full_output[packed_text_indexes] = text_out + full_output[packed_vae_token_indexes] = vae_out + + return full_output, past_key_values + def forward( self, packed_query_sequence: torch.Tensor, @@ -336,6 +476,23 @@ def forward( packed_vae_token_indexes=None, packed_text_indexes=None, ): + # SP path for gen-mode denoising (non-causal, no KV update) + if ( + mode == "gen" + and not update_past_key_values + and not is_causal + and self._is_sp_active() + and packed_vae_token_indexes is not None + and packed_text_indexes is not None + ): + return self._forward_sp_gen( + packed_query_sequence=packed_query_sequence, + packed_query_position_embeddings=packed_query_position_embeddings, + past_key_values=past_key_values, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_text_indexes=packed_text_indexes, + ) + if mode == "und": qkv, _ = self.qkv_proj(packed_query_sequence) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -449,11 +606,12 @@ def __init__( config, layer_idx: int | None = None, attn_module: type[nn.Module] | None = PackedAttentionMoT, + parallel_config: DiffusionParallelConfig | None = None, ): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = attn_module(config, layer_idx) + self.self_attn = attn_module(config, layer_idx, parallel_config=parallel_config) self.mlp = BagelMLP(config.hidden_size, config.intermediate_size, config.hidden_act) self.mlp_moe_gen = BagelMLP(config.hidden_size, config.intermediate_size, config.hidden_act) @@ -535,7 +693,7 @@ def forward( class Qwen2MoTModel(Qwen2PreTrainedModel): - def __init__(self, config): + def __init__(self, config, parallel_config: DiffusionParallelConfig | None = None): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -544,7 +702,7 @@ def __init__(self, config): self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList( [ - Qwen2MoTDecoderLayer(config, layer_idx, attn_module=PackedAttentionMoT) + Qwen2MoTDecoderLayer(config, layer_idx, attn_module=PackedAttentionMoT, parallel_config=parallel_config) for layer_idx in range(config.num_hidden_layers) ] ) @@ -626,9 +784,9 @@ def forward( class Qwen2MoTForCausalLM(Qwen2PreTrainedModel): _tied_weights_keys = ["lm_head.weight"] - def __init__(self, config): + def __init__(self, config, parallel_config: DiffusionParallelConfig | None = None): super().__init__(config) - self.model = Qwen2MoTModel(config) + self.model = Qwen2MoTModel(config, parallel_config=parallel_config) self.vocab_size = config.vocab_size # Initialize weights and apply final processing @@ -837,12 +995,24 @@ class Bagel(nn.Module): config_class = BagelConfig base_model_prefix = "bagel" - def __init__(self, language_model, vit_model, config: BagelConfig): + # Empty _sp_plan: signals SP support to the registry so that + # sp_plan_hooks_applied=True and sp_active is controlled by _sp_shard_depth. + # Actual SP logic is handled manually in the denoising methods. + _sp_plan = {} + + def __init__( + self, + language_model, + vit_model, + config: BagelConfig, + parallel_config: DiffusionParallelConfig | None = None, + ): super().__init__() self.language_model = language_model self.hidden_size = config.llm_config.hidden_size self.use_moe = "Mo" in config.llm_config.layer_module self.num_heads = config.llm_config.num_attention_heads + self.parallel_config = parallel_config if config.visual_gen: self.latent_patch_size = config.latent_patch_size @@ -869,6 +1039,94 @@ def __init__(self, language_model, vit_model, config: BagelConfig): self.config = config self._init_weights() + @property + def _sp_size(self) -> int: + if self.parallel_config is None: + return 1 + sp = self.parallel_config.sequence_parallel_size + return sp if sp is not None and sp > 1 else 1 + + def _sp_enter(self): + """Signal that we are entering an SP-sharded region (denoising).""" + if self._sp_size > 1: + ctx = get_forward_context() + ctx._sp_shard_depth += 1 + + def _sp_exit(self): + """Signal that we are leaving an SP-sharded region.""" + if self._sp_size > 1: + ctx = get_forward_context() + ctx._sp_shard_depth -= 1 + + def _split_vae_for_sp( + self, + x_t: torch.Tensor, + packed_vae_position_ids: torch.Tensor, + packed_vae_token_indexes: torch.Tensor, + packed_text_indexes: torch.Tensor, + packed_seqlens: torch.Tensor, + packed_position_ids: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Split VAE tokens across SP ranks for the denoising loop. + + Returns adjusted (x_t, packed_vae_position_ids, packed_vae_token_indexes, + packed_text_indexes, packed_seqlens, packed_position_ids) for the local rank. + """ + sp_size = self._sp_size + sp_rank = get_sequence_parallel_rank() + num_vae = x_t.shape[0] + assert num_vae % sp_size == 0, f"VAE token count {num_vae} not divisible by SP size {sp_size}" + chunk = num_vae // sp_size + start = sp_rank * chunk + end = start + chunk + + local_x_t = x_t[start:end] + local_vae_pos_ids = packed_vae_position_ids[start:end] + + # Rebuild local packed indices: + # packed sequence = [text_marker_0, local_vae_tokens..., text_marker_1] + # (for single-sample case with start_of_image / end_of_image markers) + num_text = packed_text_indexes.shape[0] + local_vae_len = chunk + local_total = num_text + local_vae_len + + # For the typical case: [start_img(0), vae_0(1), ..., vae_N(N), end_img(N+1)] + if num_text == 2: + local_text_indexes = torch.tensor([0, local_vae_len + 1], device=packed_text_indexes.device) + local_vae_indexes = torch.arange(1, 1 + local_vae_len, device=packed_vae_token_indexes.device) + else: + local_text_indexes = torch.arange(0, num_text, device=packed_text_indexes.device, dtype=torch.long) + local_vae_indexes = torch.arange( + num_text // 2, + num_text // 2 + local_vae_len, + device=packed_vae_token_indexes.device, + dtype=torch.long, + ) + + local_seqlens = torch.tensor([local_total], device=packed_seqlens.device, dtype=packed_seqlens.dtype) + + # Build local position IDs preserving global positions. + # Text markers keep their original positions; VAE tokens get + # the global positions for the local chunk. + text_pos_ids = packed_position_ids[packed_text_indexes] + vae_pos_ids_full = packed_position_ids[packed_vae_token_indexes] + local_vae_pos = vae_pos_ids_full[start:end] + local_position_ids = torch.zeros( + local_total, device=packed_position_ids.device, dtype=packed_position_ids.dtype + ) + local_position_ids[local_text_indexes] = text_pos_ids + local_position_ids[local_vae_indexes] = local_vae_pos + + return local_x_t, local_vae_pos_ids, local_vae_indexes, local_text_indexes, local_seqlens, local_position_ids + + def _gather_vae_for_sp(self, local_v_t: torch.Tensor) -> torch.Tensor: + """Gather VAE velocity outputs from all SP ranks.""" + sp_size = self._sp_size + gathered = [torch.zeros_like(local_v_t) for _ in range(sp_size)] + sp_group = get_sp_group() + dist.all_gather(gathered, local_v_t.contiguous(), group=sp_group.device_group) + return torch.cat(gathered, dim=0) + def _init_weights(self): if self.config.visual_gen: nn.init.constant_(self.llm2vae.weight, 0) @@ -1381,7 +1639,69 @@ def generate_image( cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes, ) - # ── Batched CFG mode (cfg_parallel_size=1) ── + # ── SP + CFG: sequential single-branch forwards ── + use_sp = self._sp_size > 1 + if use_sp and use_cfg_text: + for i, t in enumerate(timesteps): + timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device) + in_cfg_window = t > cfg_interval[0] and t <= cfg_interval[1] + cfg_text_scale_ = cfg_text_scale if in_cfg_window else 1.0 + cfg_img_scale_ = cfg_img_scale if in_cfg_window else 1.0 + + common = dict( + x_t=x_t, + timestep=timestep, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_vae_position_ids=packed_vae_position_ids, + packed_text_ids=packed_text_ids, + packed_text_indexes=packed_text_indexes, + packed_seqlens=packed_seqlens, + ) + + v_t = self._forward_flow_single_branch( + **common, + packed_indexes=packed_indexes, + packed_position_ids=packed_position_ids, + key_values_lens=key_values_lens, + past_key_values=past_key_values, + packed_key_value_indexes=packed_key_value_indexes, + ) + + if cfg_text_scale_ > 1.0: + cfg_text_v_t = self._forward_flow_single_branch( + **common, + packed_indexes=cfg_text_packed_query_indexes, + packed_position_ids=cfg_text_packed_position_ids, + key_values_lens=cfg_text_key_values_lens, + past_key_values=cfg_text_past_key_values, + packed_key_value_indexes=cfg_text_packed_key_value_indexes, + ) + cfg_img_v_t = None + if cfg_img_scale_ > 1.0: + cfg_img_v_t = self._forward_flow_single_branch( + **common, + packed_indexes=cfg_img_packed_query_indexes, + packed_position_ids=cfg_img_packed_position_ids, + key_values_lens=cfg_img_key_values_lens, + past_key_values=cfg_img_past_key_values, + packed_key_value_indexes=cfg_img_packed_key_value_indexes, + ) + v_t = self._combine_cfg( + v_t, + cfg_text_v_t, + cfg_img_v_t, + cfg_text_scale_, + cfg_img_scale_, + cfg_renorm_type, + cfg_renorm_min, + ) + + x_t = x_t - v_t.to(x_t.device) * dts[i] + + unpacked_latent = x_t.split((packed_seqlens - 2).tolist()) + return unpacked_latent + + # ── Batched CFG mode (cfg_parallel_size=1, no SP) ── cfg_batched = None if use_cfg_text: @@ -1681,7 +2001,79 @@ def _forward_flow_single_branch( Used by CFG parallel mode where each rank computes one branch. Returns the velocity v_t for the given branch. + Supports Ulysses / Ring SP when parallel_config.sequence_parallel_size > 1. """ + use_sp = self._sp_size > 1 + + if use_sp: + # Split VAE tokens across SP ranks + ( + local_x_t, + local_vae_pos_ids, + local_vae_indexes, + local_text_indexes, + local_seqlens, + local_position_ids, + ) = self._split_vae_for_sp( + x_t, + packed_vae_position_ids, + packed_vae_token_indexes, + packed_text_indexes, + packed_seqlens, + packed_position_ids, + ) + + packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) + packed_sequence = packed_text_embedding.new_zeros((int(local_seqlens.sum()), self.hidden_size)) + packed_sequence[local_text_indexes] = packed_text_embedding + + assert timestep.unique().shape[0] == 1 + packed_pos_embed = self.latent_pos_embed(local_vae_pos_ids) + local_timestep = timestep[: local_x_t.shape[0]] + packed_timestep_embeds = self.time_embedder(local_timestep) + x_t_emb = self.vae2llm(local_x_t) + packed_timestep_embeds + packed_pos_embed + if x_t_emb.dtype != packed_sequence.dtype: + x_t_emb = x_t_emb.to(packed_sequence.dtype) + packed_sequence[local_vae_indexes] = x_t_emb + + # Build local packed_indexes for KV cache merging + local_total = int(local_seqlens.sum()) + kv_len = int(key_values_lens.sum()) + local_packed_indexes = torch.arange( + kv_len, + kv_len + local_total, + device=packed_indexes.device, + dtype=packed_indexes.dtype, + ) + + extra_inputs = {} + if self.use_moe: + extra_inputs["mode"] = "gen" + extra_inputs["packed_vae_token_indexes"] = local_vae_indexes + extra_inputs["packed_text_indexes"] = local_text_indexes + + self._sp_enter() + try: + output = self.language_model.forward( + packed_query_sequence=packed_sequence, + query_lens=local_seqlens, + packed_query_position_ids=local_position_ids, + packed_query_indexes=local_packed_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=False, + is_causal=False, + **extra_inputs, + ) + finally: + self._sp_exit() + + local_v_t = self.llm2vae(output.packed_query_sequence) + local_v_t = local_v_t[local_vae_indexes] + return self._gather_vae_for_sp(local_v_t) + + # Original non-SP path packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size)) packed_sequence[packed_text_indexes] = packed_text_embedding @@ -1736,7 +2128,21 @@ def _forward_flow( cfg_img_scale: float = 1.0, cfg_batched: dict | None = None, ): + use_sp = self._sp_size > 1 + use_cfg = cfg_text_scale > 1.0 + + # When SP is active and batched CFG is needed, fall back to + # sequential single-branch forwards via the generate_image_parallel + # path. This path should not normally be reached because the pipeline + # selects _generate_image_parallel when SP is active. + if use_sp and use_cfg and cfg_batched is not None: + raise NotImplementedError( + "SP + batched CFG in _forward_flow is not supported. " + "Use CFG parallel mode (_generate_image_parallel) when SP is enabled." + ) + # Build query sequence (identical for all CFG branches) + x_t_raw = x_t packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size)) packed_sequence[packed_text_indexes] = packed_text_embedding @@ -1753,7 +2159,6 @@ def _forward_flow( if self.use_moe: extra_inputs["mode"] = "gen" - use_cfg = cfg_text_scale > 1.0 cfg_text_v_t = None cfg_img_v_t = None @@ -1800,24 +2205,42 @@ def _forward_flow( ] else: # ── Single forward (no CFG or outside cfg_interval) ── - if self.use_moe: - extra_inputs["packed_vae_token_indexes"] = packed_vae_token_indexes - extra_inputs["packed_text_indexes"] = packed_text_indexes - - output = self.language_model.forward( - packed_query_sequence=packed_sequence, - query_lens=packed_seqlens, - packed_query_position_ids=packed_position_ids, - packed_query_indexes=packed_indexes, - past_key_values=past_key_values, - key_values_lens=key_values_lens, - packed_key_value_indexes=packed_key_value_indexes, - update_past_key_values=False, - is_causal=False, - **extra_inputs, - ) - v_t = self.llm2vae(output.packed_query_sequence) - v_t = v_t[packed_vae_token_indexes] + if use_sp: + # Delegate to SP-aware single-branch path (use raw x_t, + # _forward_flow_single_branch does its own embedding) + v_t = self._forward_flow_single_branch( + x_t=x_t_raw, + timestep=timestep, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_vae_position_ids=packed_vae_position_ids, + packed_text_ids=packed_text_ids, + packed_text_indexes=packed_text_indexes, + packed_indexes=packed_indexes, + packed_position_ids=packed_position_ids, + packed_seqlens=packed_seqlens, + key_values_lens=key_values_lens, + past_key_values=past_key_values, + packed_key_value_indexes=packed_key_value_indexes, + ) + else: + if self.use_moe: + extra_inputs["packed_vae_token_indexes"] = packed_vae_token_indexes + extra_inputs["packed_text_indexes"] = packed_text_indexes + + output = self.language_model.forward( + packed_query_sequence=packed_sequence, + query_lens=packed_seqlens, + packed_query_position_ids=packed_position_ids, + packed_query_indexes=packed_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=False, + is_causal=False, + **extra_inputs, + ) + v_t = self.llm2vae(output.packed_query_sequence) + v_t = v_t[packed_vae_token_indexes] # ── CFG combination ── if use_cfg: diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index 9ed3481d71d..3fdbf650071 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -223,13 +223,15 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): int(required_max_id + 1), ) - self.language_model = Qwen2MoTForCausalLM(llm_config) + parallel_config = od_config.parallel_config if od_config else None + self.language_model = Qwen2MoTForCausalLM(llm_config, parallel_config=parallel_config) ae_params: AutoEncoderParams = default_ae_params() self.vae = AutoEncoder(ae_params) self.bagel = Bagel( language_model=self.language_model, vit_model=self.vit_model, + parallel_config=parallel_config, config=BagelConfig( llm_config=llm_config, vae_config=vae_cfg, diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 77e7885c01d..17707c5bf8b 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -241,7 +241,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - # Find transformer model(s) in the pipeline that have _sp_plan # Include transformer_2 for two-stage models (e.g., Wan MoE) - transformer_attrs = ["transformer", "transformer_2", "dit", "unet"] + transformer_attrs = ["transformer", "transformer_2", "dit", "unet", "bagel"] applied_count = 0 for attr in transformer_attrs: From 6e21d9c4b4ca0a3d8df92960b2f00a539292f692 Mon Sep 17 00:00:00 2001 From: princepride Date: Mon, 16 Mar 2026 09:25:41 +0000 Subject: [PATCH 02/12] simplify code Signed-off-by: princepride --- .../models/bagel/bagel_transformer.py | 90 +++++++++---------- 1 file changed, 41 insertions(+), 49 deletions(-) diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index d1f19fba559..eab8ec3cd64 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -1701,6 +1701,29 @@ def generate_image( unpacked_latent = x_t.split((packed_seqlens - 2).tolist()) return unpacked_latent + # ── SP without CFG: direct single-branch loop ── + if use_sp: + for i, t in enumerate(timesteps): + timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device) + v_t = self._forward_flow_single_branch( + x_t=x_t, + timestep=timestep, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_vae_position_ids=packed_vae_position_ids, + packed_text_ids=packed_text_ids, + packed_text_indexes=packed_text_indexes, + packed_indexes=packed_indexes, + packed_position_ids=packed_position_ids, + packed_seqlens=packed_seqlens, + key_values_lens=key_values_lens, + past_key_values=past_key_values, + packed_key_value_indexes=packed_key_value_indexes, + ) + x_t = x_t - v_t.to(x_t.device) * dts[i] + + unpacked_latent = x_t.split((packed_seqlens - 2).tolist()) + return unpacked_latent + # ── Batched CFG mode (cfg_parallel_size=1, no SP) ── cfg_batched = None @@ -2128,21 +2151,7 @@ def _forward_flow( cfg_img_scale: float = 1.0, cfg_batched: dict | None = None, ): - use_sp = self._sp_size > 1 - use_cfg = cfg_text_scale > 1.0 - - # When SP is active and batched CFG is needed, fall back to - # sequential single-branch forwards via the generate_image_parallel - # path. This path should not normally be reached because the pipeline - # selects _generate_image_parallel when SP is active. - if use_sp and use_cfg and cfg_batched is not None: - raise NotImplementedError( - "SP + batched CFG in _forward_flow is not supported. " - "Use CFG parallel mode (_generate_image_parallel) when SP is enabled." - ) - # Build query sequence (identical for all CFG branches) - x_t_raw = x_t packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size)) packed_sequence[packed_text_indexes] = packed_text_embedding @@ -2159,6 +2168,7 @@ def _forward_flow( if self.use_moe: extra_inputs["mode"] = "gen" + use_cfg = cfg_text_scale > 1.0 cfg_text_v_t = None cfg_img_v_t = None @@ -2205,42 +2215,24 @@ def _forward_flow( ] else: # ── Single forward (no CFG or outside cfg_interval) ── - if use_sp: - # Delegate to SP-aware single-branch path (use raw x_t, - # _forward_flow_single_branch does its own embedding) - v_t = self._forward_flow_single_branch( - x_t=x_t_raw, - timestep=timestep, - packed_vae_token_indexes=packed_vae_token_indexes, - packed_vae_position_ids=packed_vae_position_ids, - packed_text_ids=packed_text_ids, - packed_text_indexes=packed_text_indexes, - packed_indexes=packed_indexes, - packed_position_ids=packed_position_ids, - packed_seqlens=packed_seqlens, - key_values_lens=key_values_lens, - past_key_values=past_key_values, - packed_key_value_indexes=packed_key_value_indexes, - ) - else: - if self.use_moe: - extra_inputs["packed_vae_token_indexes"] = packed_vae_token_indexes - extra_inputs["packed_text_indexes"] = packed_text_indexes + if self.use_moe: + extra_inputs["packed_vae_token_indexes"] = packed_vae_token_indexes + extra_inputs["packed_text_indexes"] = packed_text_indexes - output = self.language_model.forward( - packed_query_sequence=packed_sequence, - query_lens=packed_seqlens, - packed_query_position_ids=packed_position_ids, - packed_query_indexes=packed_indexes, - past_key_values=past_key_values, - key_values_lens=key_values_lens, - packed_key_value_indexes=packed_key_value_indexes, - update_past_key_values=False, - is_causal=False, - **extra_inputs, - ) - v_t = self.llm2vae(output.packed_query_sequence) - v_t = v_t[packed_vae_token_indexes] + output = self.language_model.forward( + packed_query_sequence=packed_sequence, + query_lens=packed_seqlens, + packed_query_position_ids=packed_position_ids, + packed_query_indexes=packed_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=False, + is_causal=False, + **extra_inputs, + ) + v_t = self.llm2vae(output.packed_query_sequence) + v_t = v_t[packed_vae_token_indexes] # ── CFG combination ── if use_cfg: From 45ebd16dbf22cc5b950f5539e0a5f7e7feff8aca Mon Sep 17 00:00:00 2001 From: princepride Date: Wed, 18 Mar 2026 16:48:15 +0000 Subject: [PATCH 03/12] add receive kv cache distribute in kv_transfer_manager Signed-off-by: princepride --- .../models/bagel/bagel_transformer.py | 77 +++++++---------- vllm_omni/diffusion/registry.py | 2 +- .../worker/diffusion_model_runner.py | 6 +- .../omni_connectors/kv_transfer_manager.py | 86 +++++++++++++++++++ .../stage_configs/bagel_up2_ring2.yaml | 80 +++++++++++++++++ 5 files changed, 200 insertions(+), 51 deletions(-) create mode 100644 vllm_omni/model_executor/stage_configs/bagel_up2_ring2.yaml diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index eab8ec3cd64..3674646925b 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -995,11 +995,6 @@ class Bagel(nn.Module): config_class = BagelConfig base_model_prefix = "bagel" - # Empty _sp_plan: signals SP support to the registry so that - # sp_plan_hooks_applied=True and sp_active is controlled by _sp_shard_depth. - # Actual SP logic is handled manually in the denoising methods. - _sp_plan = {} - def __init__( self, language_model, @@ -1046,18 +1041,6 @@ def _sp_size(self) -> int: sp = self.parallel_config.sequence_parallel_size return sp if sp is not None and sp > 1 else 1 - def _sp_enter(self): - """Signal that we are entering an SP-sharded region (denoising).""" - if self._sp_size > 1: - ctx = get_forward_context() - ctx._sp_shard_depth += 1 - - def _sp_exit(self): - """Signal that we are leaving an SP-sharded region.""" - if self._sp_size > 1: - ctx = get_forward_context() - ctx._sp_shard_depth -= 1 - def _split_vae_for_sp( self, x_t: torch.Tensor, @@ -1084,24 +1067,18 @@ def _split_vae_for_sp( local_vae_pos_ids = packed_vae_position_ids[start:end] # Rebuild local packed indices: - # packed sequence = [text_marker_0, local_vae_tokens..., text_marker_1] - # (for single-sample case with start_of_image / end_of_image markers) + # packed sequence = [start_of_image, local_vae_tokens..., end_of_image] + # BAGEL always has exactly 2 text markers (start/end_of_image). num_text = packed_text_indexes.shape[0] + assert num_text == 2, f"Expected exactly 2 text markers (start/end_of_image), got {num_text}" + assert packed_seqlens.numel() == 1, ( + f"SP currently supports single-image batches only, got {packed_seqlens.numel()} sequences" + ) local_vae_len = chunk local_total = num_text + local_vae_len - # For the typical case: [start_img(0), vae_0(1), ..., vae_N(N), end_img(N+1)] - if num_text == 2: - local_text_indexes = torch.tensor([0, local_vae_len + 1], device=packed_text_indexes.device) - local_vae_indexes = torch.arange(1, 1 + local_vae_len, device=packed_vae_token_indexes.device) - else: - local_text_indexes = torch.arange(0, num_text, device=packed_text_indexes.device, dtype=torch.long) - local_vae_indexes = torch.arange( - num_text // 2, - num_text // 2 + local_vae_len, - device=packed_vae_token_indexes.device, - dtype=torch.long, - ) + local_text_indexes = torch.tensor([0, local_vae_len + 1], device=packed_text_indexes.device) + local_vae_indexes = torch.arange(1, 1 + local_vae_len, device=packed_vae_token_indexes.device) local_seqlens = torch.tensor([local_total], device=packed_seqlens.device, dtype=packed_seqlens.dtype) @@ -2059,9 +2036,17 @@ def _forward_flow_single_branch( x_t_emb = x_t_emb.to(packed_sequence.dtype) packed_sequence[local_vae_indexes] = x_t_emb - # Build local packed_indexes for KV cache merging + # Build local packed_indexes for KV cache merging. + # In the denoising loop packed_indexes is always contiguous + # (arange(kv_len, kv_len + total)), so we can safely build + # the local slice from scratch. local_total = int(local_seqlens.sum()) kv_len = int(key_values_lens.sum()) + original_total = int(packed_seqlens.sum()) + assert torch.equal( + packed_indexes, + torch.arange(kv_len, kv_len + original_total, device=packed_indexes.device, dtype=packed_indexes.dtype), + ), "packed_indexes must be contiguous for SP; non-contiguous layout not supported" local_packed_indexes = torch.arange( kv_len, kv_len + local_total, @@ -2075,22 +2060,18 @@ def _forward_flow_single_branch( extra_inputs["packed_vae_token_indexes"] = local_vae_indexes extra_inputs["packed_text_indexes"] = local_text_indexes - self._sp_enter() - try: - output = self.language_model.forward( - packed_query_sequence=packed_sequence, - query_lens=local_seqlens, - packed_query_position_ids=local_position_ids, - packed_query_indexes=local_packed_indexes, - past_key_values=past_key_values, - key_values_lens=key_values_lens, - packed_key_value_indexes=packed_key_value_indexes, - update_past_key_values=False, - is_causal=False, - **extra_inputs, - ) - finally: - self._sp_exit() + output = self.language_model.forward( + packed_query_sequence=packed_sequence, + query_lens=local_seqlens, + packed_query_position_ids=local_position_ids, + packed_query_indexes=local_packed_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=False, + is_causal=False, + **extra_inputs, + ) local_v_t = self.llm2vae(output.packed_query_sequence) local_v_t = local_v_t[local_vae_indexes] diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 17707c5bf8b..77e7885c01d 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -241,7 +241,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - # Find transformer model(s) in the pipeline that have _sp_plan # Include transformer_2 for two-stage models (e.g., Wan MoE) - transformer_attrs = ["transformer", "transformer_2", "dit", "unet", "bagel"] + transformer_attrs = ["transformer", "transformer_2", "dit", "unet"] applied_count = 0 for attr in transformer_attrs: diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index accb173e1a0..09363d3d208 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -200,8 +200,10 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: use_hsdp = self.od_config.parallel_config.use_hsdp grad_context = torch.no_grad() if use_hsdp else torch.inference_mode() with grad_context: - # The manager handles the check for need_recv_cache internally - self.kv_transfer_manager.receive_multi_kv_cache( + # Receive KV cache from upstream stage. + # Uses broadcast-aware variant so multi-GPU stages (e.g. SP) + # correctly distribute the cache from rank 0 to all workers. + self.kv_transfer_manager.receive_multi_kv_cache_distributed( req, cfg_kv_collect_func=getattr(self.od_config, "cfg_kv_collect_func", None), target_device=getattr(self.pipeline, "device", None), diff --git a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py index 2806f7315d1..34c91ec5911 100644 --- a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py +++ b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py @@ -531,3 +531,89 @@ def receive_multi_kv_cache( logger.exception("Failed to collect CFG KV caches for %s", request_id) return primary_ok + + def receive_multi_kv_cache_distributed( + self, + req: Any, + cfg_kv_collect_func: Callable | None = None, + target_device: torch.device | None = None, + ) -> bool: + """Broadcast-aware wrapper around :meth:`receive_multi_kv_cache`. + + SharedMemory connector is single-reader: once rank 0 consumes the + segment it is deleted. For multi-GPU stages (e.g. sequence-parallel) + only rank 0 receives; the result is then broadcast to every other + rank via the world process-group. + + For single-worker stages this is equivalent to calling + :meth:`receive_multi_kv_cache` directly. + """ + from vllm_omni.diffusion.distributed.parallel_state import get_world_group + + world = get_world_group() + + if world.world_size <= 1: + return self.receive_multi_kv_cache(req, cfg_kv_collect_func, target_device) + + # --- rank 0: receive to CPU (needed for pickle-based broadcast) --- + if world.rank_in_group == 0: + self.receive_multi_kv_cache(req, cfg_kv_collect_func, torch.device("cpu")) + + kv_payload: dict[str, object] = {} + for attr in ("past_key_values", "kv_metadata"): + val = getattr(req, attr, None) + if val is not None: + kv_payload[attr] = val + + if hasattr(req, "sampling_params") and req.sampling_params is not None: + for key in list(vars(req.sampling_params).keys()): + if (key.startswith("cfg_") and key.endswith("_past_key_values")) or key in ( + "past_key_values", + "kv_metadata", + ): + val = getattr(req.sampling_params, key, None) + if val is not None: + kv_payload[f"sp.{key}"] = val + + payload_list = [kv_payload] + torch.distributed.broadcast_object_list(payload_list, src=world.ranks[0], group=world.cpu_group) + kv_payload = payload_list[0] + else: + payload_list: list[dict[str, object] | None] = [None] + torch.distributed.broadcast_object_list(payload_list, src=world.ranks[0], group=world.cpu_group) + kv_payload = payload_list[0] + + # --- apply on ALL ranks (rank 0 also needs CPU→GPU move) --- + if not kv_payload: + return False + + for attr in ("past_key_values", "kv_metadata"): + val = kv_payload.get(attr) + if val is not None: + if target_device is not None: + val = _move_to_device(val, target_device) + setattr(req, attr, val) + + if hasattr(req, "sampling_params") and req.sampling_params is not None: + for key, val in kv_payload.items(): + if key.startswith("sp."): + if target_device is not None: + val = _move_to_device(val, target_device) + setattr(req.sampling_params, key[3:], val) + + return True + + +def _move_to_device(obj: object, device: torch.device) -> object: + """Recursively move tensors inside a KV cache object to *device*.""" + if isinstance(obj, torch.Tensor): + return obj.to(device).contiguous() if obj.device != device else obj + if hasattr(obj, "__dict__"): + for k, v in vars(obj).items(): + setattr(obj, k, _move_to_device(v, device)) + return obj + if isinstance(obj, dict): + return {k: _move_to_device(v, device) for k, v in obj.items()} + if isinstance(obj, list): + return [_move_to_device(v, device) for v in obj] + return obj diff --git a/vllm_omni/model_executor/stage_configs/bagel_up2_ring2.yaml b/vllm_omni/model_executor/stage_configs/bagel_up2_ring2.yaml new file mode 100644 index 00000000000..573222b6836 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/bagel_up2_ring2.yaml @@ -0,0 +1,80 @@ +# Stage config for BAGEL SP: ulysses=2, ring=2 (4 GPUs) + +stage_args: + - stage_id: 0 + stage_type: llm + prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: OmniBagelForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.45 + enforce_eager: true + trust_remote_code: true + engine_output_type: text + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + omni_kv_config: + need_send_cache: true + kv_transfer_criteria: + type: prefill_finished + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 52 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 1 + stage_type: diffusion + cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches + runtime: + devices: "2,3,4,5" + max_batch_size: 1 + engine_args: + model_stage: dit + gpu_memory_utilization: 0.45 + enforce_eager: true + trust_remote_code: true + engine_output_type: image + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + parallel_config: + ulysses_degree: 2 + ring_degree: 2 + omni_kv_config: + need_recv_cache: true + engine_input_source: [0] + final_output: true + final_output_type: image + is_comprehension: false + default_sampling_params: + seed: 52 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + connectors: + shared_memory_connector: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 + edges: + - from: 0 + to: 1 + window_size: -1 From 076bc46f5651a23aa8abac0ea506188455ed7c5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Thu, 19 Mar 2026 09:50:41 +0800 Subject: [PATCH 04/12] Clean up comments in diffusion model runner MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Removed comments about KV cache handling in the inference context. Signed-off-by: 汪志鹏 --- vllm_omni/diffusion/worker/diffusion_model_runner.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 09363d3d208..5d260cb8c8b 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -200,9 +200,7 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: use_hsdp = self.od_config.parallel_config.use_hsdp grad_context = torch.no_grad() if use_hsdp else torch.inference_mode() with grad_context: - # Receive KV cache from upstream stage. - # Uses broadcast-aware variant so multi-GPU stages (e.g. SP) - # correctly distribute the cache from rank 0 to all workers. + # The manager handles the check for need_recv_cache internally self.kv_transfer_manager.receive_multi_kv_cache_distributed( req, cfg_kv_collect_func=getattr(self.od_config, "cfg_kv_collect_func", None), From d5aa95cd86d386d7a2d6be4167b2928a160b2e41 Mon Sep 17 00:00:00 2001 From: princepride Date: Thu, 19 Mar 2026 08:43:04 +0000 Subject: [PATCH 05/12] change yaml name Signed-off-by: princepride --- .../stage_configs/{bagel_up2_ring2.yaml => bagel_usp2_ring2.yaml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename vllm_omni/model_executor/stage_configs/{bagel_up2_ring2.yaml => bagel_usp2_ring2.yaml} (100%) diff --git a/vllm_omni/model_executor/stage_configs/bagel_up2_ring2.yaml b/vllm_omni/model_executor/stage_configs/bagel_usp2_ring2.yaml similarity index 100% rename from vllm_omni/model_executor/stage_configs/bagel_up2_ring2.yaml rename to vllm_omni/model_executor/stage_configs/bagel_usp2_ring2.yaml From f985b99bd445e28250905a9d9e23483017d625c2 Mon Sep 17 00:00:00 2001 From: princepride Date: Thu, 19 Mar 2026 15:48:40 +0000 Subject: [PATCH 06/12] update docs Signed-off-by: princepride --- .../diffusion/parallelism_acceleration.md | 76 +++++++++++++++++++ docs/user_guide/diffusion_acceleration.md | 2 +- 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index 20f09d2c09c..55d0550d867 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -39,6 +39,7 @@ The following table shows which models are currently supported by parallelism me | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ✅ | ✅ | ❌ | N/A | ✅ | | **FLUX.2-dev** | `black-forest-labs/FLUX.2-dev` | ❌ | ❌ | ❌ | ✅ | ❌ | N/A | ✅ | | **HunyuanImage3.0** | `tencent/HunyuanImage-3.0`, `tencent/HunyuanImage-3.0-Instruct` | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | +| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ✅ | ✅ | ❌ | N/A | ❌ | | **DreamID-Omni** | `XuGuo699/DreamID-Omni` | ❌ | ❌ | ✅ | ❌ | ❌ | N/A | ❌ | !!! note "TP Limitations for Diffusion Models" @@ -282,6 +283,81 @@ To measure the parallelism methods, we run benchmarks with **Qwen/Qwen-Image** m | Hybrid Ulysses + Ring | 2 | 2 | 24.3s | 1.87x | +#### BAGEL Multi-Stage SP + +BAGEL uses a multi-stage pipeline (LLM prefill + diffusion), so sequence parallelism is configured via stage config YAML files rather than `DiffusionParallelConfig`. The LLM stage runs on a single GPU while the diffusion stage is parallelized across multiple GPUs. + +BAGEL SP supports Ulysses, Ring, and hybrid (Ulysses + Ring) modes. It is also compatible with TeaCache for cache acceleration. + +##### Offline Inference + +```python +from vllm_omni.entrypoints.omni import Omni + +# Ulysses=2, Ring=2 using a stage config that assigns 4 GPUs to the diffusion stage +omni = Omni( + model="ByteDance-Seed/BAGEL-7B-MoT", + stage_configs_path="vllm_omni/model_executor/stage_configs/bagel_usp2_ring2.yaml", +) + +params_list = omni.default_sampling_params_list +params_list[0].max_tokens = 1 +params_list[1].num_inference_steps = 15 +params_list[1].seed = 52 +params_list[1].extra_args = { + "cfg_text_scale": 4.0, + "cfg_img_scale": 1.5, +} + +outputs = list(omni.generate( + prompts=[{"prompt": "<|im_start|>A cute cat<|im_end|>", "modalities": ["image"]}], + sampling_params_list=params_list, +)) +``` + +##### SP + TeaCache + +BAGEL SP is compatible with TeaCache. Simply add `cache_backend` and `cache_config` when using a SP stage config: + +```python +from vllm_omni.entrypoints.omni import Omni + +# SP (Ulysses=2) + TeaCache +omni = Omni( + model="ByteDance-Seed/BAGEL-7B-MoT", + stage_configs_path="vllm_omni/model_executor/stage_configs/bagel_usp2.yaml", + cache_backend="tea_cache", + cache_config={"rel_l1_thresh": 0.2}, +) +``` + +!!! warning "SP + Cache-DiT" + SP combined with Cache-DiT is not recommended for BAGEL as it may produce degraded image quality. + +##### Stage Configs + +Pre-built stage configs are available under `vllm_omni/model_executor/stage_configs/`: + +- `bagel.yaml` — baseline (no SP) +- `bagel_usp2_ring2.yaml` — Ulysses=2, Ring=2 (4 GPUs for diffusion) + +You can create custom stage configs by setting `parallel_config.ulysses_degree` and/or `parallel_config.ring_degree` in the diffusion stage's `engine_args`. + +##### Benchmarks + +!!! note "Benchmark Disclaimer" + These benchmarks are provided for **general reference only**. Actual performance may vary based on hardware, model weights, and inference settings. + +The following benchmarks were measured on **BAGEL-7B-MoT** generating **1024×1024** images with **50** inference steps and CFG (text_scale=4.0, img_scale=1.5). Each diffusion worker uses ~27.20 GiB model VRAM. + +| Configuration | Diffusion GPUs | E2E Latency | Speedup | Notes | +|---------------|:--------------:|:-----------:|:-------:|-------| +| **Baseline (no SP)** | 1 | 19.04s | 1.0x | LLM on GPU 0, diffusion on GPU 1 | +| Ulysses=2 | 2 | 14.28s | 1.33x | | +| Ring=2 | 2 | 14.30s | 1.33x | | +| Ulysses=2 + Ring=2 | 4 | 14.45s | 1.32x | 4-GPU comm overhead offsets extra parallelism at this resolution | +| Ulysses=2 + TeaCache | 2 | 14.34s | 1.33x | SP compatible with TeaCache | + ### CFG-Parallel #### Offline Inference diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index e9179b5752a..39adc03a35a 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -65,7 +65,7 @@ The following table shows which models are currently supported by each accelerat | **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ (TP=2 only) | ✅ | | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | -| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | +| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | | **NextStep-1.1** | `stepfun-ai/NextStep-1.1` | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | | **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | From dd27b4a5a09b5c503e76cd2b9e449b4d3df70f2b Mon Sep 17 00:00:00 2001 From: princepride Date: Thu, 19 Mar 2026 16:30:53 +0000 Subject: [PATCH 07/12] update docs Signed-off-by: princepride --- .../diffusion/parallelism_acceleration.md | 75 ------------------- 1 file changed, 75 deletions(-) diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index 55d0550d867..5b13d81e788 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -283,81 +283,6 @@ To measure the parallelism methods, we run benchmarks with **Qwen/Qwen-Image** m | Hybrid Ulysses + Ring | 2 | 2 | 24.3s | 1.87x | -#### BAGEL Multi-Stage SP - -BAGEL uses a multi-stage pipeline (LLM prefill + diffusion), so sequence parallelism is configured via stage config YAML files rather than `DiffusionParallelConfig`. The LLM stage runs on a single GPU while the diffusion stage is parallelized across multiple GPUs. - -BAGEL SP supports Ulysses, Ring, and hybrid (Ulysses + Ring) modes. It is also compatible with TeaCache for cache acceleration. - -##### Offline Inference - -```python -from vllm_omni.entrypoints.omni import Omni - -# Ulysses=2, Ring=2 using a stage config that assigns 4 GPUs to the diffusion stage -omni = Omni( - model="ByteDance-Seed/BAGEL-7B-MoT", - stage_configs_path="vllm_omni/model_executor/stage_configs/bagel_usp2_ring2.yaml", -) - -params_list = omni.default_sampling_params_list -params_list[0].max_tokens = 1 -params_list[1].num_inference_steps = 15 -params_list[1].seed = 52 -params_list[1].extra_args = { - "cfg_text_scale": 4.0, - "cfg_img_scale": 1.5, -} - -outputs = list(omni.generate( - prompts=[{"prompt": "<|im_start|>A cute cat<|im_end|>", "modalities": ["image"]}], - sampling_params_list=params_list, -)) -``` - -##### SP + TeaCache - -BAGEL SP is compatible with TeaCache. Simply add `cache_backend` and `cache_config` when using a SP stage config: - -```python -from vllm_omni.entrypoints.omni import Omni - -# SP (Ulysses=2) + TeaCache -omni = Omni( - model="ByteDance-Seed/BAGEL-7B-MoT", - stage_configs_path="vllm_omni/model_executor/stage_configs/bagel_usp2.yaml", - cache_backend="tea_cache", - cache_config={"rel_l1_thresh": 0.2}, -) -``` - -!!! warning "SP + Cache-DiT" - SP combined with Cache-DiT is not recommended for BAGEL as it may produce degraded image quality. - -##### Stage Configs - -Pre-built stage configs are available under `vllm_omni/model_executor/stage_configs/`: - -- `bagel.yaml` — baseline (no SP) -- `bagel_usp2_ring2.yaml` — Ulysses=2, Ring=2 (4 GPUs for diffusion) - -You can create custom stage configs by setting `parallel_config.ulysses_degree` and/or `parallel_config.ring_degree` in the diffusion stage's `engine_args`. - -##### Benchmarks - -!!! note "Benchmark Disclaimer" - These benchmarks are provided for **general reference only**. Actual performance may vary based on hardware, model weights, and inference settings. - -The following benchmarks were measured on **BAGEL-7B-MoT** generating **1024×1024** images with **50** inference steps and CFG (text_scale=4.0, img_scale=1.5). Each diffusion worker uses ~27.20 GiB model VRAM. - -| Configuration | Diffusion GPUs | E2E Latency | Speedup | Notes | -|---------------|:--------------:|:-----------:|:-------:|-------| -| **Baseline (no SP)** | 1 | 19.04s | 1.0x | LLM on GPU 0, diffusion on GPU 1 | -| Ulysses=2 | 2 | 14.28s | 1.33x | | -| Ring=2 | 2 | 14.30s | 1.33x | | -| Ulysses=2 + Ring=2 | 4 | 14.45s | 1.32x | 4-GPU comm overhead offsets extra parallelism at this resolution | -| Ulysses=2 + TeaCache | 2 | 14.34s | 1.33x | SP compatible with TeaCache | - ### CFG-Parallel #### Offline Inference From 3497b3338e77e9e1ef88027496547785dac1f35f Mon Sep 17 00:00:00 2001 From: princepride Date: Thu, 19 Mar 2026 16:48:40 +0000 Subject: [PATCH 08/12] Add L4 Signed-off-by: princepride --- .../online_serving/test_bagel_expansion.py | 31 ++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/tests/e2e/online_serving/test_bagel_expansion.py b/tests/e2e/online_serving/test_bagel_expansion.py index 75e18ecfa3d..e2d75e0d199 100644 --- a/tests/e2e/online_serving/test_bagel_expansion.py +++ b/tests/e2e/online_serving/test_bagel_expansion.py @@ -7,6 +7,8 @@ - Cache-DiT - CFG-Parallel - Tensor-Parallel +- Ulysses-SP +- Ring-Attention assert_diffusion_response validates successful generation and the expected 512x512 resolution. @@ -31,7 +33,8 @@ def _get_diffusion_feature_cases(model: str): """Return L4 diffusion feature cases for Bagel. - TeaCache, Cache-DiT, CFG-Parallel, Tensor-Parallel. + TeaCache, Cache-DiT, CFG-Parallel, Tensor-Parallel, + Ulysses-SP, Ring-Attention. """ return [ @@ -87,6 +90,30 @@ def _get_diffusion_feature_cases(model: str): id="parallel_tp_2", marks=PARALLEL_FEATURE_MARKS, ), + # Ulysses-SP degree=2 (2 GPUs) + pytest.param( + OmniServerParams( + model=model, + server_args=[ + "--usp", + "2", + ], + ), + id="sp_ulysses_2", + marks=PARALLEL_FEATURE_MARKS, + ), + # Ring-Attention degree=2 (2 GPUs) + pytest.param( + OmniServerParams( + model=model, + server_args=[ + "--ring", + "2", + ], + ), + id="sp_ring_2", + marks=PARALLEL_FEATURE_MARKS, + ), ] @@ -108,6 +135,8 @@ def test_bagel( - Cache-DiT - CFG-Parallel (size=2) - Tensor-Parallel (size=2) + - Ulysses-SP (degree=2) + - Ring-Attention (degree=2) Validation is delegated to assert_diffusion_response in tests.conftest, which checks output dimensions and basic correctness. From 2f7c9ec9f4aa20610fe3e14ea27b399cce3773f0 Mon Sep 17 00:00:00 2001 From: princepride Date: Sat, 21 Mar 2026 14:22:18 +0000 Subject: [PATCH 09/12] Fix test mock and add broadcast comment per review feedback Co-Authored-By: Claude Opus 4.6 Signed-off-by: princepride --- tests/diffusion/test_diffusion_model_runner.py | 1 + vllm_omni/distributed/omni_connectors/kv_transfer_manager.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/tests/diffusion/test_diffusion_model_runner.py b/tests/diffusion/test_diffusion_model_runner.py index 2f9aadfd8ad..88b17147e85 100644 --- a/tests/diffusion/test_diffusion_model_runner.py +++ b/tests/diffusion/test_diffusion_model_runner.py @@ -59,6 +59,7 @@ def _make_runner(cache_backend, cache_backend_name: str, enable_cache_dit_summar runner.kv_transfer_manager = SimpleNamespace( receive_kv_cache=lambda req, target_device=None: None, receive_multi_kv_cache=lambda req, cfg_kv_collect_func=None, target_device=None: None, + receive_multi_kv_cache_distributed=lambda req, cfg_kv_collect_func=None, target_device=None: None, ) return runner diff --git a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py index 34c91ec5911..1f493843837 100644 --- a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py +++ b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py @@ -576,6 +576,10 @@ def receive_multi_kv_cache_distributed( kv_payload[f"sp.{key}"] = val payload_list = [kv_payload] + # Use broadcast_object_list (pickle-based) instead of broadcast_tensor_dict + # because the KV cache is a heterogeneous nested structure (NaiveCache objects + # with metadata + tensors), not a flat tensor dict. This runs once before + # the denoising loop so the serialization cost is negligible. torch.distributed.broadcast_object_list(payload_list, src=world.ranks[0], group=world.cpu_group) kv_payload = payload_list[0] else: From 8a6fa3783d76c3e17cef8358556174ab0f4a3ee0 Mon Sep 17 00:00:00 2001 From: princepride Date: Mon, 23 Mar 2026 09:04:41 +0000 Subject: [PATCH 10/12] update bagel usp yaml Signed-off-by: princepride --- .../stage_configs/{bagel_usp2_ring2.yaml => bagel_usp2.yaml} | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) rename vllm_omni/model_executor/stage_configs/{bagel_usp2_ring2.yaml => bagel_usp2.yaml} (96%) diff --git a/vllm_omni/model_executor/stage_configs/bagel_usp2_ring2.yaml b/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml similarity index 96% rename from vllm_omni/model_executor/stage_configs/bagel_usp2_ring2.yaml rename to vllm_omni/model_executor/stage_configs/bagel_usp2.yaml index 573222b6836..33f332b0d44 100644 --- a/vllm_omni/model_executor/stage_configs/bagel_usp2_ring2.yaml +++ b/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml @@ -40,7 +40,8 @@ stage_args: stage_type: diffusion cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches runtime: - devices: "2,3,4,5" + # devices: "0,1,2,3" + devices: "0,1" max_batch_size: 1 engine_args: model_stage: dit @@ -54,7 +55,7 @@ stage_args: tensor_parallel_size: 1 parallel_config: ulysses_degree: 2 - ring_degree: 2 + # ring_degree: 2 omni_kv_config: need_recv_cache: true engine_input_source: [0] From 1fa34c2bf20380654872cb5d51b3683930d8365b Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 23 Mar 2026 17:46:53 +0800 Subject: [PATCH 11/12] Apply suggestion from @wtomin Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com> --- vllm_omni/model_executor/stage_configs/bagel_usp2.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml b/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml index 33f332b0d44..632c227f360 100644 --- a/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml +++ b/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml @@ -1,4 +1,4 @@ -# Stage config for BAGEL SP: ulysses=2, ring=2 (4 GPUs) +# Stage config for BAGEL SP: ulysses=2 (2 GPUs) stage_args: - stage_id: 0 From cc30a5029cf413974f269c95164966c771fb5148 Mon Sep 17 00:00:00 2001 From: princepride Date: Mon, 23 Mar 2026 12:00:07 +0000 Subject: [PATCH 12/12] fix docs bug Signed-off-by: princepride --- docs/.nav.yml | 2 ++ vllm_omni/model_executor/models/hunyuan_image3/siglip2.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/.nav.yml b/docs/.nav.yml index b7d08e77e91..7550b6f3b41 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -91,10 +91,12 @@ nav: - design/feature/cache_dit.md - design/feature/teacache.md - design/feature/async_chunk_design.md + - design/feature/vae_parallel.md - Module Design: - design/module/ar_module.md - design/module/dit_module.md - design/module/entrypoint_module.md + - design/module/async_omni_architecture.md - Docs Guide: contributing/DOCS_GUIDE.md - API Reference: - api/README.md diff --git a/vllm_omni/model_executor/models/hunyuan_image3/siglip2.py b/vllm_omni/model_executor/models/hunyuan_image3/siglip2.py index e608baeaa0d..09b6c1d6401 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/siglip2.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/siglip2.py @@ -393,7 +393,7 @@ def forward( Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. - [What are attention masks?](../glossary#attention-mask) + What are attention masks? See HuggingFace documentation for details. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail.