diff --git a/docs/.nav.yml b/docs/.nav.yml index bfa9365f6f6..fd4bec2ef0b 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -92,11 +92,13 @@ nav: - design/feature/cache_dit.md - design/feature/teacache.md - design/feature/async_chunk_design.md + - design/feature/vae_parallel.md - design/feature/diffusion_step_execution.md - Module Design: - design/module/ar_module.md - design/module/dit_module.md - design/module/entrypoint_module.md + - design/module/async_omni_architecture.md - Docs Guide: contributing/DOCS_GUIDE.md - API Reference: - api/README.md diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index 12889719f62..4b5f6bf8476 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -39,6 +39,7 @@ The following table shows which models are currently supported by parallelism me | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ✅ | ✅ | ❌ | N/A | ✅ | | **FLUX.2-dev** | `black-forest-labs/FLUX.2-dev` | ❌ | ❌ | ❌ | ✅ | ❌ | N/A | ✅ | | **HunyuanImage3.0** | `tencent/HunyuanImage-3.0`, `tencent/HunyuanImage-3.0-Instruct` | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | +| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ✅ | ✅ | ❌ | N/A | ❌ | | **DreamID-Omni** | `XuGuo699/DreamID-Omni` | ❌ | ❌ | ✅ | ❌ | ❌ | N/A | ❌ | | **OmniGen2** | `OmniGen2/OmniGen2` | ❌ | ❌ | ❌ | ✅ | ❌ | N/A | ❌ | diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 0140819f028..ad97909c571 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -66,7 +66,7 @@ The following table shows which models are currently supported by each accelerat | **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ (TP=2 only) | ✅ | | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | -| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | +| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | | **NextStep-1.1** | `stepfun-ai/NextStep-1.1` | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | | **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | diff --git a/tests/diffusion/test_diffusion_model_runner.py b/tests/diffusion/test_diffusion_model_runner.py index 2f9aadfd8ad..88b17147e85 100644 --- a/tests/diffusion/test_diffusion_model_runner.py +++ b/tests/diffusion/test_diffusion_model_runner.py @@ -59,6 +59,7 @@ def _make_runner(cache_backend, cache_backend_name: str, enable_cache_dit_summar runner.kv_transfer_manager = SimpleNamespace( receive_kv_cache=lambda req, target_device=None: None, receive_multi_kv_cache=lambda req, cfg_kv_collect_func=None, target_device=None: None, + receive_multi_kv_cache_distributed=lambda req, cfg_kv_collect_func=None, target_device=None: None, ) return runner diff --git a/tests/e2e/online_serving/test_bagel_expansion.py b/tests/e2e/online_serving/test_bagel_expansion.py index 75e18ecfa3d..e2d75e0d199 100644 --- a/tests/e2e/online_serving/test_bagel_expansion.py +++ b/tests/e2e/online_serving/test_bagel_expansion.py @@ -7,6 +7,8 @@ - Cache-DiT - CFG-Parallel - Tensor-Parallel +- Ulysses-SP +- Ring-Attention assert_diffusion_response validates successful generation and the expected 512x512 resolution. @@ -31,7 +33,8 @@ def _get_diffusion_feature_cases(model: str): """Return L4 diffusion feature cases for Bagel. - TeaCache, Cache-DiT, CFG-Parallel, Tensor-Parallel. + TeaCache, Cache-DiT, CFG-Parallel, Tensor-Parallel, + Ulysses-SP, Ring-Attention. """ return [ @@ -87,6 +90,30 @@ def _get_diffusion_feature_cases(model: str): id="parallel_tp_2", marks=PARALLEL_FEATURE_MARKS, ), + # Ulysses-SP degree=2 (2 GPUs) + pytest.param( + OmniServerParams( + model=model, + server_args=[ + "--usp", + "2", + ], + ), + id="sp_ulysses_2", + marks=PARALLEL_FEATURE_MARKS, + ), + # Ring-Attention degree=2 (2 GPUs) + pytest.param( + OmniServerParams( + model=model, + server_args=[ + "--ring", + "2", + ], + ), + id="sp_ring_2", + marks=PARALLEL_FEATURE_MARKS, + ), ] @@ -108,6 +135,8 @@ def test_bagel( - Cache-DiT - CFG-Parallel (size=2) - Tensor-Parallel (size=2) + - Ulysses-SP (degree=2) + - Ring-Attention (degree=2) Validation is delegated to assert_diffusion_response in tests.conftest, which checks output dimensions and basic correctness. diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index 09144e05a91..3674646925b 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -13,6 +13,7 @@ import numpy as np import torch +import torch.distributed as dist from torch import nn from torch.nn.attention.flex_attention import flex_attention from transformers.models.qwen2.configuration_qwen2 import Qwen2Config @@ -31,12 +32,18 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.transformers_utils.configs.bagel import BagelConfig +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata as DiffusionAttentionMetadata from vllm_omni.diffusion.attention.backends.utils.fa import flash_attn_varlen_func +from vllm_omni.diffusion.attention.layer import Attention as DiffusionAttention +from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.diffusion.distributed.parallel_state import ( get_cfg_group, get_classifier_free_guidance_rank, get_classifier_free_guidance_world_size, + get_sequence_parallel_rank, + get_sp_group, ) +from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available from vllm_omni.diffusion.layers.rope import RotaryEmbedding @@ -266,10 +273,11 @@ class PackedAttentionMoT(nn.Module): - qkv_proj_moe_gen : stacks q_proj_moe_gen + k_proj_moe_gen + v_proj_moe_gen (gen vae) """ - def __init__(self, config, layer_idx: int | None = None): + def __init__(self, config, layer_idx: int | None = None, parallel_config: DiffusionParallelConfig | None = None): super().__init__() self.layer_idx = layer_idx self.hidden_size = config.hidden_size + self.parallel_config = parallel_config tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads @@ -321,6 +329,138 @@ def __init__(self, config, layer_idx: int | None = None): self.rotary_op = RotaryEmbedding(is_neox_style=True) + # SP (Ulysses / Ring) attention for generation mode denoising + sp_size = parallel_config.sequence_parallel_size if parallel_config is not None else 1 + if sp_size is not None and sp_size > 1: + self.sp_attn = DiffusionAttention( + num_heads=self.total_num_heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + num_kv_heads=self.total_num_kv_heads, + ) + else: + self.sp_attn = None + + def _is_sp_active(self) -> bool: + """Check if SP is active for this attention layer.""" + if self.sp_attn is None: + return False + if not is_forward_context_available(): + return False + return get_forward_context().sp_active + + def _forward_sp_gen( + self, + packed_query_sequence: torch.Tensor, + packed_query_position_embeddings: torch.Tensor, + past_key_values: NaiveCache | None, + packed_vae_token_indexes: torch.Tensor, + packed_text_indexes: torch.Tensor, + ) -> tuple[torch.Tensor, NaiveCache | None]: + """SP-aware attention for gen mode denoising. + + Converts packed format to batched (1, S, H, D) and uses the diffusion + Attention layer (Ulysses / Ring) with joint mechanism: + - Main Q/K/V: VAE tokens (split across SP ranks) + - Joint Q: text marker Q (replicated) + - Joint K/V: KV cache K/V + text marker K/V (replicated) + """ + packed_query_sequence = packed_query_sequence.to(torch.bfloat16) + + packed_text_query_sequence = packed_query_sequence[packed_text_indexes] + packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes] + + # Project text tokens through base qkv + text_qkv, _ = self.qkv_proj(packed_text_query_sequence) + text_q, text_k, text_v = text_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Project vae tokens through moe_gen qkv + vae_qkv, _ = self.qkv_proj_moe_gen(packed_vae_query_sequence) + vae_q, vae_k, vae_v = vae_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Reshape to (tokens, heads, head_dim) + text_q = text_q.view(-1, self.num_heads, self.head_dim) + text_k = text_k.view(-1, self.num_kv_heads, self.head_dim) + text_v = text_v.view(-1, self.num_kv_heads, self.head_dim) + vae_q = vae_q.view(-1, self.num_heads, self.head_dim) + vae_k = vae_k.view(-1, self.num_kv_heads, self.head_dim) + vae_v = vae_v.view(-1, self.num_kv_heads, self.head_dim) + + # Apply QK norms + text_q = self.q_norm(text_q.to(torch.float32)) + text_k = self.k_norm(text_k.to(torch.float32)) + vae_q = self.q_norm_moe_gen(vae_q.to(torch.float32)) + vae_k = self.k_norm_moe_gen(vae_k.to(torch.float32)) + + # Apply RoPE - need to build per-token cos/sin for text and vae separately + # packed_query_position_embeddings are ordered as the packed sequence + cos_full, sin_full = [x[..., : self.head_dim // 2] for x in packed_query_position_embeddings] + + # Extract cos/sin for text and vae positions + text_cos = cos_full[packed_text_indexes] + text_sin = sin_full[packed_text_indexes] + vae_cos = cos_full[packed_vae_token_indexes] + vae_sin = sin_full[packed_vae_token_indexes] + + text_q = self.rotary_op(text_q.to(text_cos.dtype).unsqueeze(0), text_cos, text_sin).squeeze(0) + text_k = self.rotary_op(text_k.to(text_cos.dtype).unsqueeze(0), text_cos, text_sin).squeeze(0) + vae_q = self.rotary_op(vae_q.to(vae_cos.dtype).unsqueeze(0), vae_cos, vae_sin).squeeze(0) + vae_k = self.rotary_op(vae_k.to(vae_cos.dtype).unsqueeze(0), vae_cos, vae_sin).squeeze(0) + + text_q = text_q.to(torch.bfloat16) + text_k = text_k.to(torch.bfloat16) + text_v = text_v.to(torch.bfloat16) + vae_q = vae_q.to(torch.bfloat16) + vae_k = vae_k.to(torch.bfloat16) + vae_v = vae_v.to(torch.bfloat16) + + # Build joint K/V: [kv_cache, text_markers] (replicated across SP ranks) + if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None: + cache_k = past_key_values.key_cache[self.layer_idx] + cache_v = past_key_values.value_cache[self.layer_idx] + joint_k = torch.cat([cache_k, text_k], dim=0).unsqueeze(0) + joint_v = torch.cat([cache_v, text_v], dim=0).unsqueeze(0) + else: + joint_k = text_k.unsqueeze(0) + joint_v = text_v.unsqueeze(0) + + # Reshape to batched (1, S, H, D) for diffusion Attention + vae_q_4d = vae_q.unsqueeze(0) + vae_k_4d = vae_k.unsqueeze(0) + vae_v_4d = vae_v.unsqueeze(0) + text_q_4d = text_q.unsqueeze(0) + + # Call SP-aware attention: VAE as main, text+cache as joint + attn_out = self.sp_attn( + vae_q_4d, + vae_k_4d, + vae_v_4d, + DiffusionAttentionMetadata( + joint_query=text_q_4d, + joint_key=joint_k, + joint_value=joint_v, + joint_strategy="front", + ), + ) + # attn_out: (1, text_len + local_vae_len, H, D) + text_len = text_q.shape[0] + attn_out = attn_out.squeeze(0) # (text_len + local_vae_len, H, D) + text_attn = attn_out[:text_len].reshape(text_len, self.q_size) + vae_attn = attn_out[text_len:].reshape(-1, self.q_size) + + # Apply output projections + text_out, _ = self.o_proj(text_attn) + vae_out, _ = self.o_proj_moe_gen(vae_attn) + + # Merge back into packed format + total_len = packed_query_sequence.shape[0] + full_output = text_out.new_zeros((total_len, self.hidden_size)) + full_output[packed_text_indexes] = text_out + full_output[packed_vae_token_indexes] = vae_out + + return full_output, past_key_values + def forward( self, packed_query_sequence: torch.Tensor, @@ -336,6 +476,23 @@ def forward( packed_vae_token_indexes=None, packed_text_indexes=None, ): + # SP path for gen-mode denoising (non-causal, no KV update) + if ( + mode == "gen" + and not update_past_key_values + and not is_causal + and self._is_sp_active() + and packed_vae_token_indexes is not None + and packed_text_indexes is not None + ): + return self._forward_sp_gen( + packed_query_sequence=packed_query_sequence, + packed_query_position_embeddings=packed_query_position_embeddings, + past_key_values=past_key_values, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_text_indexes=packed_text_indexes, + ) + if mode == "und": qkv, _ = self.qkv_proj(packed_query_sequence) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -449,11 +606,12 @@ def __init__( config, layer_idx: int | None = None, attn_module: type[nn.Module] | None = PackedAttentionMoT, + parallel_config: DiffusionParallelConfig | None = None, ): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = attn_module(config, layer_idx) + self.self_attn = attn_module(config, layer_idx, parallel_config=parallel_config) self.mlp = BagelMLP(config.hidden_size, config.intermediate_size, config.hidden_act) self.mlp_moe_gen = BagelMLP(config.hidden_size, config.intermediate_size, config.hidden_act) @@ -535,7 +693,7 @@ def forward( class Qwen2MoTModel(Qwen2PreTrainedModel): - def __init__(self, config): + def __init__(self, config, parallel_config: DiffusionParallelConfig | None = None): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -544,7 +702,7 @@ def __init__(self, config): self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList( [ - Qwen2MoTDecoderLayer(config, layer_idx, attn_module=PackedAttentionMoT) + Qwen2MoTDecoderLayer(config, layer_idx, attn_module=PackedAttentionMoT, parallel_config=parallel_config) for layer_idx in range(config.num_hidden_layers) ] ) @@ -626,9 +784,9 @@ def forward( class Qwen2MoTForCausalLM(Qwen2PreTrainedModel): _tied_weights_keys = ["lm_head.weight"] - def __init__(self, config): + def __init__(self, config, parallel_config: DiffusionParallelConfig | None = None): super().__init__(config) - self.model = Qwen2MoTModel(config) + self.model = Qwen2MoTModel(config, parallel_config=parallel_config) self.vocab_size = config.vocab_size # Initialize weights and apply final processing @@ -837,12 +995,19 @@ class Bagel(nn.Module): config_class = BagelConfig base_model_prefix = "bagel" - def __init__(self, language_model, vit_model, config: BagelConfig): + def __init__( + self, + language_model, + vit_model, + config: BagelConfig, + parallel_config: DiffusionParallelConfig | None = None, + ): super().__init__() self.language_model = language_model self.hidden_size = config.llm_config.hidden_size self.use_moe = "Mo" in config.llm_config.layer_module self.num_heads = config.llm_config.num_attention_heads + self.parallel_config = parallel_config if config.visual_gen: self.latent_patch_size = config.latent_patch_size @@ -869,6 +1034,76 @@ def __init__(self, language_model, vit_model, config: BagelConfig): self.config = config self._init_weights() + @property + def _sp_size(self) -> int: + if self.parallel_config is None: + return 1 + sp = self.parallel_config.sequence_parallel_size + return sp if sp is not None and sp > 1 else 1 + + def _split_vae_for_sp( + self, + x_t: torch.Tensor, + packed_vae_position_ids: torch.Tensor, + packed_vae_token_indexes: torch.Tensor, + packed_text_indexes: torch.Tensor, + packed_seqlens: torch.Tensor, + packed_position_ids: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Split VAE tokens across SP ranks for the denoising loop. + + Returns adjusted (x_t, packed_vae_position_ids, packed_vae_token_indexes, + packed_text_indexes, packed_seqlens, packed_position_ids) for the local rank. + """ + sp_size = self._sp_size + sp_rank = get_sequence_parallel_rank() + num_vae = x_t.shape[0] + assert num_vae % sp_size == 0, f"VAE token count {num_vae} not divisible by SP size {sp_size}" + chunk = num_vae // sp_size + start = sp_rank * chunk + end = start + chunk + + local_x_t = x_t[start:end] + local_vae_pos_ids = packed_vae_position_ids[start:end] + + # Rebuild local packed indices: + # packed sequence = [start_of_image, local_vae_tokens..., end_of_image] + # BAGEL always has exactly 2 text markers (start/end_of_image). + num_text = packed_text_indexes.shape[0] + assert num_text == 2, f"Expected exactly 2 text markers (start/end_of_image), got {num_text}" + assert packed_seqlens.numel() == 1, ( + f"SP currently supports single-image batches only, got {packed_seqlens.numel()} sequences" + ) + local_vae_len = chunk + local_total = num_text + local_vae_len + + local_text_indexes = torch.tensor([0, local_vae_len + 1], device=packed_text_indexes.device) + local_vae_indexes = torch.arange(1, 1 + local_vae_len, device=packed_vae_token_indexes.device) + + local_seqlens = torch.tensor([local_total], device=packed_seqlens.device, dtype=packed_seqlens.dtype) + + # Build local position IDs preserving global positions. + # Text markers keep their original positions; VAE tokens get + # the global positions for the local chunk. + text_pos_ids = packed_position_ids[packed_text_indexes] + vae_pos_ids_full = packed_position_ids[packed_vae_token_indexes] + local_vae_pos = vae_pos_ids_full[start:end] + local_position_ids = torch.zeros( + local_total, device=packed_position_ids.device, dtype=packed_position_ids.dtype + ) + local_position_ids[local_text_indexes] = text_pos_ids + local_position_ids[local_vae_indexes] = local_vae_pos + + return local_x_t, local_vae_pos_ids, local_vae_indexes, local_text_indexes, local_seqlens, local_position_ids + + def _gather_vae_for_sp(self, local_v_t: torch.Tensor) -> torch.Tensor: + """Gather VAE velocity outputs from all SP ranks.""" + sp_size = self._sp_size + gathered = [torch.zeros_like(local_v_t) for _ in range(sp_size)] + sp_group = get_sp_group() + dist.all_gather(gathered, local_v_t.contiguous(), group=sp_group.device_group) + return torch.cat(gathered, dim=0) + def _init_weights(self): if self.config.visual_gen: nn.init.constant_(self.llm2vae.weight, 0) @@ -1381,7 +1616,92 @@ def generate_image( cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes, ) - # ── Batched CFG mode (cfg_parallel_size=1) ── + # ── SP + CFG: sequential single-branch forwards ── + use_sp = self._sp_size > 1 + if use_sp and use_cfg_text: + for i, t in enumerate(timesteps): + timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device) + in_cfg_window = t > cfg_interval[0] and t <= cfg_interval[1] + cfg_text_scale_ = cfg_text_scale if in_cfg_window else 1.0 + cfg_img_scale_ = cfg_img_scale if in_cfg_window else 1.0 + + common = dict( + x_t=x_t, + timestep=timestep, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_vae_position_ids=packed_vae_position_ids, + packed_text_ids=packed_text_ids, + packed_text_indexes=packed_text_indexes, + packed_seqlens=packed_seqlens, + ) + + v_t = self._forward_flow_single_branch( + **common, + packed_indexes=packed_indexes, + packed_position_ids=packed_position_ids, + key_values_lens=key_values_lens, + past_key_values=past_key_values, + packed_key_value_indexes=packed_key_value_indexes, + ) + + if cfg_text_scale_ > 1.0: + cfg_text_v_t = self._forward_flow_single_branch( + **common, + packed_indexes=cfg_text_packed_query_indexes, + packed_position_ids=cfg_text_packed_position_ids, + key_values_lens=cfg_text_key_values_lens, + past_key_values=cfg_text_past_key_values, + packed_key_value_indexes=cfg_text_packed_key_value_indexes, + ) + cfg_img_v_t = None + if cfg_img_scale_ > 1.0: + cfg_img_v_t = self._forward_flow_single_branch( + **common, + packed_indexes=cfg_img_packed_query_indexes, + packed_position_ids=cfg_img_packed_position_ids, + key_values_lens=cfg_img_key_values_lens, + past_key_values=cfg_img_past_key_values, + packed_key_value_indexes=cfg_img_packed_key_value_indexes, + ) + v_t = self._combine_cfg( + v_t, + cfg_text_v_t, + cfg_img_v_t, + cfg_text_scale_, + cfg_img_scale_, + cfg_renorm_type, + cfg_renorm_min, + ) + + x_t = x_t - v_t.to(x_t.device) * dts[i] + + unpacked_latent = x_t.split((packed_seqlens - 2).tolist()) + return unpacked_latent + + # ── SP without CFG: direct single-branch loop ── + if use_sp: + for i, t in enumerate(timesteps): + timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device) + v_t = self._forward_flow_single_branch( + x_t=x_t, + timestep=timestep, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_vae_position_ids=packed_vae_position_ids, + packed_text_ids=packed_text_ids, + packed_text_indexes=packed_text_indexes, + packed_indexes=packed_indexes, + packed_position_ids=packed_position_ids, + packed_seqlens=packed_seqlens, + key_values_lens=key_values_lens, + past_key_values=past_key_values, + packed_key_value_indexes=packed_key_value_indexes, + ) + x_t = x_t - v_t.to(x_t.device) * dts[i] + + unpacked_latent = x_t.split((packed_seqlens - 2).tolist()) + return unpacked_latent + + # ── Batched CFG mode (cfg_parallel_size=1, no SP) ── cfg_batched = None if use_cfg_text: @@ -1681,7 +2001,83 @@ def _forward_flow_single_branch( Used by CFG parallel mode where each rank computes one branch. Returns the velocity v_t for the given branch. + Supports Ulysses / Ring SP when parallel_config.sequence_parallel_size > 1. """ + use_sp = self._sp_size > 1 + + if use_sp: + # Split VAE tokens across SP ranks + ( + local_x_t, + local_vae_pos_ids, + local_vae_indexes, + local_text_indexes, + local_seqlens, + local_position_ids, + ) = self._split_vae_for_sp( + x_t, + packed_vae_position_ids, + packed_vae_token_indexes, + packed_text_indexes, + packed_seqlens, + packed_position_ids, + ) + + packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) + packed_sequence = packed_text_embedding.new_zeros((int(local_seqlens.sum()), self.hidden_size)) + packed_sequence[local_text_indexes] = packed_text_embedding + + assert timestep.unique().shape[0] == 1 + packed_pos_embed = self.latent_pos_embed(local_vae_pos_ids) + local_timestep = timestep[: local_x_t.shape[0]] + packed_timestep_embeds = self.time_embedder(local_timestep) + x_t_emb = self.vae2llm(local_x_t) + packed_timestep_embeds + packed_pos_embed + if x_t_emb.dtype != packed_sequence.dtype: + x_t_emb = x_t_emb.to(packed_sequence.dtype) + packed_sequence[local_vae_indexes] = x_t_emb + + # Build local packed_indexes for KV cache merging. + # In the denoising loop packed_indexes is always contiguous + # (arange(kv_len, kv_len + total)), so we can safely build + # the local slice from scratch. + local_total = int(local_seqlens.sum()) + kv_len = int(key_values_lens.sum()) + original_total = int(packed_seqlens.sum()) + assert torch.equal( + packed_indexes, + torch.arange(kv_len, kv_len + original_total, device=packed_indexes.device, dtype=packed_indexes.dtype), + ), "packed_indexes must be contiguous for SP; non-contiguous layout not supported" + local_packed_indexes = torch.arange( + kv_len, + kv_len + local_total, + device=packed_indexes.device, + dtype=packed_indexes.dtype, + ) + + extra_inputs = {} + if self.use_moe: + extra_inputs["mode"] = "gen" + extra_inputs["packed_vae_token_indexes"] = local_vae_indexes + extra_inputs["packed_text_indexes"] = local_text_indexes + + output = self.language_model.forward( + packed_query_sequence=packed_sequence, + query_lens=local_seqlens, + packed_query_position_ids=local_position_ids, + packed_query_indexes=local_packed_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=False, + is_causal=False, + **extra_inputs, + ) + + local_v_t = self.llm2vae(output.packed_query_sequence) + local_v_t = local_v_t[local_vae_indexes] + return self._gather_vae_for_sp(local_v_t) + + # Original non-SP path packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size)) packed_sequence[packed_text_indexes] = packed_text_embedding diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index dc8a91b3bbe..29ac411f5e1 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -224,13 +224,15 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): int(required_max_id + 1), ) - self.language_model = Qwen2MoTForCausalLM(llm_config) + parallel_config = od_config.parallel_config if od_config else None + self.language_model = Qwen2MoTForCausalLM(llm_config, parallel_config=parallel_config) ae_params: AutoEncoderParams = default_ae_params() self.vae = AutoEncoder(ae_params) self.bagel = Bagel( language_model=self.language_model, vit_model=self.vit_model, + parallel_config=parallel_config, config=BagelConfig( llm_config=llm_config, vae_config=vae_cfg, diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index e3c27f94545..7be999492fa 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -215,7 +215,7 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: grad_context = torch.no_grad() if use_hsdp else torch.inference_mode() with grad_context: # The manager handles the check for need_recv_cache internally - self.kv_transfer_manager.receive_multi_kv_cache( + self.kv_transfer_manager.receive_multi_kv_cache_distributed( req, cfg_kv_collect_func=getattr(self.od_config, "cfg_kv_collect_func", None), target_device=getattr(self.pipeline, "device", None), diff --git a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py index 2806f7315d1..1f493843837 100644 --- a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py +++ b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py @@ -531,3 +531,93 @@ def receive_multi_kv_cache( logger.exception("Failed to collect CFG KV caches for %s", request_id) return primary_ok + + def receive_multi_kv_cache_distributed( + self, + req: Any, + cfg_kv_collect_func: Callable | None = None, + target_device: torch.device | None = None, + ) -> bool: + """Broadcast-aware wrapper around :meth:`receive_multi_kv_cache`. + + SharedMemory connector is single-reader: once rank 0 consumes the + segment it is deleted. For multi-GPU stages (e.g. sequence-parallel) + only rank 0 receives; the result is then broadcast to every other + rank via the world process-group. + + For single-worker stages this is equivalent to calling + :meth:`receive_multi_kv_cache` directly. + """ + from vllm_omni.diffusion.distributed.parallel_state import get_world_group + + world = get_world_group() + + if world.world_size <= 1: + return self.receive_multi_kv_cache(req, cfg_kv_collect_func, target_device) + + # --- rank 0: receive to CPU (needed for pickle-based broadcast) --- + if world.rank_in_group == 0: + self.receive_multi_kv_cache(req, cfg_kv_collect_func, torch.device("cpu")) + + kv_payload: dict[str, object] = {} + for attr in ("past_key_values", "kv_metadata"): + val = getattr(req, attr, None) + if val is not None: + kv_payload[attr] = val + + if hasattr(req, "sampling_params") and req.sampling_params is not None: + for key in list(vars(req.sampling_params).keys()): + if (key.startswith("cfg_") and key.endswith("_past_key_values")) or key in ( + "past_key_values", + "kv_metadata", + ): + val = getattr(req.sampling_params, key, None) + if val is not None: + kv_payload[f"sp.{key}"] = val + + payload_list = [kv_payload] + # Use broadcast_object_list (pickle-based) instead of broadcast_tensor_dict + # because the KV cache is a heterogeneous nested structure (NaiveCache objects + # with metadata + tensors), not a flat tensor dict. This runs once before + # the denoising loop so the serialization cost is negligible. + torch.distributed.broadcast_object_list(payload_list, src=world.ranks[0], group=world.cpu_group) + kv_payload = payload_list[0] + else: + payload_list: list[dict[str, object] | None] = [None] + torch.distributed.broadcast_object_list(payload_list, src=world.ranks[0], group=world.cpu_group) + kv_payload = payload_list[0] + + # --- apply on ALL ranks (rank 0 also needs CPU→GPU move) --- + if not kv_payload: + return False + + for attr in ("past_key_values", "kv_metadata"): + val = kv_payload.get(attr) + if val is not None: + if target_device is not None: + val = _move_to_device(val, target_device) + setattr(req, attr, val) + + if hasattr(req, "sampling_params") and req.sampling_params is not None: + for key, val in kv_payload.items(): + if key.startswith("sp."): + if target_device is not None: + val = _move_to_device(val, target_device) + setattr(req.sampling_params, key[3:], val) + + return True + + +def _move_to_device(obj: object, device: torch.device) -> object: + """Recursively move tensors inside a KV cache object to *device*.""" + if isinstance(obj, torch.Tensor): + return obj.to(device).contiguous() if obj.device != device else obj + if hasattr(obj, "__dict__"): + for k, v in vars(obj).items(): + setattr(obj, k, _move_to_device(v, device)) + return obj + if isinstance(obj, dict): + return {k: _move_to_device(v, device) for k, v in obj.items()} + if isinstance(obj, list): + return [_move_to_device(v, device) for v in obj] + return obj diff --git a/vllm_omni/model_executor/models/hunyuan_image3/siglip2.py b/vllm_omni/model_executor/models/hunyuan_image3/siglip2.py index e608baeaa0d..09b6c1d6401 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/siglip2.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/siglip2.py @@ -393,7 +393,7 @@ def forward( Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. - [What are attention masks?](../glossary#attention-mask) + What are attention masks? See HuggingFace documentation for details. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. diff --git a/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml b/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml new file mode 100644 index 00000000000..632c227f360 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml @@ -0,0 +1,81 @@ +# Stage config for BAGEL SP: ulysses=2 (2 GPUs) + +stage_args: + - stage_id: 0 + stage_type: llm + prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: OmniBagelForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.45 + enforce_eager: true + trust_remote_code: true + engine_output_type: text + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + omni_kv_config: + need_send_cache: true + kv_transfer_criteria: + type: prefill_finished + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 52 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 1 + stage_type: diffusion + cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches + runtime: + # devices: "0,1,2,3" + devices: "0,1" + max_batch_size: 1 + engine_args: + model_stage: dit + gpu_memory_utilization: 0.45 + enforce_eager: true + trust_remote_code: true + engine_output_type: image + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + parallel_config: + ulysses_degree: 2 + # ring_degree: 2 + omni_kv_config: + need_recv_cache: true + engine_input_source: [0] + final_output: true + final_output_type: image + is_comprehension: false + default_sampling_params: + seed: 52 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + connectors: + shared_memory_connector: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 + edges: + - from: 0 + to: 1 + window_size: -1