diff --git a/tests/diffusion/models/bagel/test_trajectory_recording.py b/tests/diffusion/models/bagel/test_trajectory_recording.py index a5ceb0cc298..730b918e41b 100644 --- a/tests/diffusion/models/bagel/test_trajectory_recording.py +++ b/tests/diffusion/models/bagel/test_trajectory_recording.py @@ -56,10 +56,7 @@ def _make_generate_args(num_tokens=NUM_TOKENS, hidden_dim=HIDDEN_DIM, cfg=False) packed_vae_token_indexes=torch.arange(2, seq_len, dtype=torch.long), packed_seqlens=torch.tensor([seq_len], dtype=torch.int), packed_position_ids=torch.arange(seq_len, dtype=torch.long), - packed_indexes=torch.arange(seq_len, dtype=torch.long), past_key_values=NaiveCache(1), - key_values_lens=torch.tensor([0], dtype=torch.int), - packed_key_value_indexes=torch.zeros(0, dtype=torch.long), num_timesteps=NUM_TIMESTEPS, timestep_shift=1.0, cfg_text_scale=1.0, @@ -68,11 +65,8 @@ def _make_generate_args(num_tokens=NUM_TOKENS, hidden_dim=HIDDEN_DIM, cfg=False) if cfg: base |= dict( cfg_text_scale=4.0, - cfg_text_packed_query_indexes=torch.arange(seq_len, dtype=torch.long), cfg_text_packed_position_ids=torch.arange(seq_len, dtype=torch.long), cfg_text_past_key_values=NaiveCache(1), - cfg_text_key_values_lens=torch.tensor([0], dtype=torch.int), - cfg_text_packed_key_value_indexes=torch.zeros(0, dtype=torch.long), ) return base diff --git a/vllm_omni/diffusion/attention/backends/sdpa.py b/vllm_omni/diffusion/attention/backends/sdpa.py index ab71e753b25..c41e4e39001 100644 --- a/vllm_omni/diffusion/attention/backends/sdpa.py +++ b/vllm_omni/diffusion/attention/backends/sdpa.py @@ -89,6 +89,7 @@ def __init__( ) -> None: self.causal = causal self.softmax_scale = softmax_scale + self.requires_gqa = num_heads != num_kv_heads if backend_kwargs: logger.warning("SDPAImpl ignoring backend_kwargs: %s", list(backend_kwargs.keys())) @@ -115,6 +116,7 @@ def _forward_impl( dropout_p=0.0, is_causal=self.causal, scale=self.softmax_scale, + enable_gqa=self.requires_gqa, ) out = output.permute(0, 2, 1, 3) return out diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py index 91259bc8ee3..1f0c38aee58 100644 --- a/vllm_omni/diffusion/cache/teacache/extractors.py +++ b/vllm_omni/diffusion/cache/teacache/extractors.py @@ -307,12 +307,9 @@ def extract_bagel_context( packed_vae_position_ids: torch.LongTensor, packed_text_ids: torch.LongTensor, packed_text_indexes: torch.LongTensor, - packed_indexes: torch.LongTensor, packed_position_ids: torch.LongTensor, packed_seqlens: torch.IntTensor, - key_values_lens: torch.IntTensor, past_key_values: Any, - packed_key_value_indexes: torch.LongTensor, **kwargs: Any, ) -> CacheContext: """ @@ -326,12 +323,9 @@ def extract_bagel_context( packed_vae_position_ids: Position IDs for VAE tokens packed_text_ids: Text token IDs packed_text_indexes: Indexes for text tokens in packed sequence - packed_indexes: Global indexes packed_position_ids: Global position IDs packed_seqlens: Sequence lengths - key_values_lens: KV cache lengths past_key_values: KV cache - packed_key_value_indexes: KV cache indexes **kwargs: Additional keyword arguments Returns: @@ -375,10 +369,7 @@ def run_transformer_blocks(): packed_query_sequence=packed_sequence, query_lens=packed_seqlens, packed_query_position_ids=packed_position_ids, - packed_query_indexes=packed_indexes, past_key_values=past_key_values, - key_values_lens=key_values_lens, - packed_key_value_indexes=packed_key_value_indexes, update_past_key_values=False, is_causal=False, **extra_inputs, diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index ecaba8e0866..c61db037b4c 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -14,7 +14,6 @@ import torch import torch.distributed as dist from torch import nn -from torch.nn.attention.flex_attention import flex_attention from transformers.models.qwen2.configuration_qwen2 import Qwen2Config from transformers.models.qwen2.modeling_qwen2 import ( Qwen2PreTrainedModel, @@ -36,7 +35,6 @@ from vllm.transformers_utils.configs.bagel import BagelConfig from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata as DiffusionAttentionMetadata -from vllm_omni.diffusion.attention.backends.utils.fa import flash_attn_varlen_func from vllm_omni.diffusion.attention.layer import Attention as DiffusionAttention from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.diffusion.distributed.parallel_state import ( @@ -200,11 +198,6 @@ def forward(self, x): return x -torch._dynamo.config.cache_size_limit = 512 -torch._dynamo.config.accumulated_cache_size_limit = 4096 -flex_attention = torch.compile(flex_attention) - - class Qwen2MoTConfig(Qwen2Config): """Configuration for Qwen2MoT (Mixture of Tokens) model. @@ -373,45 +366,55 @@ def __init__( self.rotary_op = RotaryEmbedding(is_neox_style=True) - # SP (Ulysses / Ring) attention for generation mode denoising - sp_size = parallel_config.sequence_parallel_size if parallel_config is not None else 1 - if sp_size is not None and sp_size > 1: - self.sp_attn = DiffusionAttention( - num_heads=self.total_num_heads, - head_size=self.head_dim, - softmax_scale=1.0 / (self.head_dim**0.5), - causal=False, - num_kv_heads=self.total_num_kv_heads, - ) - else: - self.sp_attn = None + self.attn_causal = DiffusionAttention( + num_heads=self.total_num_heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=True, + num_kv_heads=self.total_num_kv_heads, + ) + self.attn_noncausal = DiffusionAttention( + num_heads=self.total_num_heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + num_kv_heads=self.total_num_kv_heads, + ) def _is_sp_active(self) -> bool: """Check if SP is active for this attention layer.""" - if self.sp_attn is None: - return False if not is_forward_context_available(): return False return get_forward_context().sp_active - def _forward_sp_gen( + def _forward_gen( self, packed_query_sequence: torch.Tensor, packed_query_position_embeddings: torch.Tensor, past_key_values: NaiveCache | None, packed_vae_token_indexes: torch.Tensor, packed_text_indexes: torch.Tensor, + update_past_key_values: bool = False, ) -> tuple[torch.Tensor, NaiveCache | None]: - """SP-aware attention for gen mode denoising. - - Converts packed format to batched (1, S, H, D) and uses the diffusion - Attention layer (Ulysses / Ring) with joint mechanism: - - Main Q/K/V: VAE tokens (split across SP ranks) - - Joint Q: text marker Q (replicated) - - Joint K/V: KV cache K/V + text marker K/V (replicated) + """Forward pass for generation mode. + + This path does the following: + + 1. Apply qkv projection to the text seq & vae seqs + 2. Reshape both to 3D & apply RMS norms + 3. Apply RoPE to text / VAE components independently + 4. Create the full K/V; Bagel currently manages its own KV cache + (NaiveCache) independently since it is a diffusion model + 5. Apply non-causal attention, while taking sequence parallelism into account + 6. Apply output projections on the split parts + 7. Merge back into the packed format + 8. Update the NaiveCache. + + TODO (Alex): it would be best to remove packing from Bagel to simplify the code. + Currently we shouldn't need it in the model, and it would be ideal to handle + packing/batching etc in a more model agnostic way. """ packed_query_sequence = packed_query_sequence.to(torch.bfloat16) - packed_text_query_sequence = packed_query_sequence[packed_text_indexes] packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes] @@ -440,8 +443,6 @@ def _forward_sp_gen( # Apply RoPE - need to build per-token cos/sin for text and vae separately # packed_query_position_embeddings are ordered as the packed sequence cos_full, sin_full = [x[..., : self.head_dim // 2] for x in packed_query_position_embeddings] - - # Extract cos/sin for text and vae positions text_cos = cos_full[packed_text_indexes] text_sin = sin_full[packed_text_indexes] vae_cos = cos_full[packed_vae_token_indexes] @@ -463,33 +464,35 @@ def _forward_sp_gen( if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None: cache_k = past_key_values.key_cache[self.layer_idx] cache_v = past_key_values.value_cache[self.layer_idx] - joint_k = torch.cat([cache_k, text_k], dim=0).unsqueeze(0) - joint_v = torch.cat([cache_v, text_v], dim=0).unsqueeze(0) + ctx_k = torch.cat([cache_k, text_k], dim=0) + ctx_v = torch.cat([cache_v, text_v], dim=0) else: - joint_k = text_k.unsqueeze(0) - joint_v = text_v.unsqueeze(0) - - # Reshape to batched (1, S, H, D) for diffusion Attention - vae_q_4d = vae_q.unsqueeze(0) - vae_k_4d = vae_k.unsqueeze(0) - vae_v_4d = vae_v.unsqueeze(0) - text_q_4d = text_q.unsqueeze(0) - - # Call SP-aware attention: VAE as main, text+cache as joint - attn_out = self.sp_attn( - vae_q_4d, - vae_k_4d, - vae_v_4d, - DiffusionAttentionMetadata( - joint_query=text_q_4d, - joint_key=joint_k, - joint_value=joint_v, - joint_strategy="front", - ), - ) - # attn_out: (1, text_len + local_vae_len, H, D) + ctx_k = text_k + ctx_v = text_v + + # NOTE: we reshape to batched (1, S, H, D) for diffusion Attention + # attn_out should be: (1, text_len + local_vae_len, H, D) + if self._is_sp_active(): + # Joint mechanism keeps text+cache replicated across SP ranks + attn_out = self.attn_noncausal( + vae_q.unsqueeze(0), + vae_k.unsqueeze(0), + vae_v.unsqueeze(0), + DiffusionAttentionMetadata( + joint_query=text_q.unsqueeze(0), + joint_key=ctx_k.unsqueeze(0), + joint_value=ctx_v.unsqueeze(0), + joint_strategy="front", + ), + ) + else: + q = torch.cat([text_q, vae_q], dim=0).unsqueeze(0) + k = torch.cat([ctx_k, vae_k], dim=0).unsqueeze(0) + v = torch.cat([ctx_v, vae_v], dim=0).unsqueeze(0) + attn_out = self.attn_noncausal(q, k, v) + text_len = text_q.shape[0] - attn_out = attn_out.squeeze(0) # (text_len + local_vae_len, H, D) + attn_out = attn_out.squeeze(0) text_attn = attn_out[:text_len].reshape(text_len, self.q_size) vae_attn = attn_out[text_len:].reshape(-1, self.q_size) @@ -503,145 +506,108 @@ def _forward_sp_gen( full_output[packed_text_indexes] = text_out full_output[packed_vae_token_indexes] = vae_out + if update_past_key_values: + new_k = torch.cat([ctx_k, vae_k], dim=0) + new_v = torch.cat([ctx_v, vae_v], dim=0) + past_key_values.key_cache[self.layer_idx] = new_k + past_key_values.value_cache[self.layer_idx] = new_v + return full_output, past_key_values + def _forward_und( + self, + packed_query_sequence: torch.Tensor, + packed_query_position_embeddings: torch.Tensor, + past_key_values: NaiveCache | None, + is_causal: bool, + update_past_key_values: bool = True, + ) -> tuple[torch.Tensor, NaiveCache | None]: + """Forward pass for understanding mode. + + This path does the following (not hard to read): + + 1. Apply qkv projection to the text seq + 2. Reshape to 3D & apply RMS norms + 3. Apply RoPE + 4. Create the full K/V; Bagel currently manages its own KV cache + (NaiveCache) independently since it is a diffusion model + 5. Apply attention based on causality kwarg + 6. Apply output projection + 7. Update the NaiveCache. + """ + + qkv, _ = self.qkv_proj(packed_query_sequence) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view(-1, self.num_heads, self.head_dim) + k = k.view(-1, self.num_kv_heads, self.head_dim) + v = v.view(-1, self.num_kv_heads, self.head_dim) + q = self.q_norm(q) + k = self.k_norm(k) + + cos, sin = [x[..., : self.head_dim // 2] for x in packed_query_position_embeddings] + q = self.rotary_op(q.to(cos.dtype).unsqueeze(0), cos, sin).squeeze(0) + k = self.rotary_op(k.to(cos.dtype).unsqueeze(0), cos, sin).squeeze(0) + + q = q.to(torch.bfloat16) + k = k.to(torch.bfloat16) + v = v.to(torch.bfloat16) + + if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None: + cache_k = past_key_values.key_cache[self.layer_idx] + cache_v = past_key_values.value_cache[self.layer_idx] + full_k = torch.cat([cache_k, k], dim=0) + full_v = torch.cat([cache_v, v], dim=0) + else: + full_k = k + full_v = v + + attn = self.attn_causal if is_causal else self.attn_noncausal + attn_out = attn( + q.unsqueeze(0), + full_k.unsqueeze(0), + full_v.unsqueeze(0), + ) + + attn_out = attn_out.squeeze(0).reshape(-1, self.q_size) + attn_out, _ = self.o_proj(attn_out) + + if update_past_key_values: + past_key_values.key_cache[self.layer_idx] = full_k + past_key_values.value_cache[self.layer_idx] = full_v + + return attn_out, past_key_values + def forward( self, packed_query_sequence: torch.Tensor, query_lens: torch.Tensor, packed_query_position_embeddings: torch.Tensor, - packed_query_indexes: torch.Tensor, past_key_values: NaiveCache | None = None, - key_values_lens: torch.Tensor | None = None, - packed_key_value_indexes: torch.Tensor | None = None, update_past_key_values=True, is_causal=True, mode="und", packed_vae_token_indexes=None, packed_text_indexes=None, ): - # SP path for gen-mode denoising (non-causal, no KV update) - if ( - mode == "gen" - and not update_past_key_values - and not is_causal - and self._is_sp_active() - and packed_vae_token_indexes is not None - and packed_text_indexes is not None - ): - return self._forward_sp_gen( + if mode == "gen": + if is_causal: + raise ValueError("Generation model for Bagel requires non-causal attention") + return self._forward_gen( packed_query_sequence=packed_query_sequence, packed_query_position_embeddings=packed_query_position_embeddings, past_key_values=past_key_values, packed_vae_token_indexes=packed_vae_token_indexes, packed_text_indexes=packed_text_indexes, + update_past_key_values=update_past_key_values, ) - if mode == "und": - qkv, _ = self.qkv_proj(packed_query_sequence) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - packed_query_states = q.view(-1, self.num_heads, self.head_dim) - packed_key_states = k.view(-1, self.num_kv_heads, self.head_dim) - packed_value_states = v.view(-1, self.num_kv_heads, self.head_dim) - packed_query_states = self.q_norm(packed_query_states) - packed_key_states = self.k_norm(packed_key_states) - elif mode == "gen": - packed_query_sequence = packed_query_sequence.to(torch.bfloat16) - - packed_text_query_sequence = packed_query_sequence[packed_text_indexes] - packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes] - - # Project text tokens through base qkv - text_qkv, _ = self.qkv_proj(packed_text_query_sequence) - text_q, text_k, text_v = text_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - # Project vae tokens through moe_gen qkv - vae_qkv, _ = self.qkv_proj_moe_gen(packed_vae_query_sequence) - vae_q, vae_k, vae_v = vae_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - # Merge into packed tensors - total_len = packed_query_sequence.shape[0] - packed_query_states = packed_query_sequence.new_zeros((total_len, self.q_size)) - packed_key_states = packed_query_sequence.new_zeros((total_len, self.kv_size)) - packed_value_states = packed_query_sequence.new_zeros((total_len, self.kv_size)) - - packed_query_states[packed_text_indexes] = text_q - packed_query_states[packed_vae_token_indexes] = vae_q - packed_key_states[packed_text_indexes] = text_k - packed_key_states[packed_vae_token_indexes] = vae_k - packed_value_states[packed_text_indexes] = text_v - packed_value_states[packed_vae_token_indexes] = vae_v - - packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim) - packed_key_states = packed_key_states.view(-1, self.num_kv_heads, self.head_dim) - packed_value_states = packed_value_states.view(-1, self.num_kv_heads, self.head_dim) - - packed_query_states = packed_query_states.to(torch.float32) - packed_query_states[packed_text_indexes] = self.q_norm(packed_query_states[packed_text_indexes]) - packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen( - packed_query_states[packed_vae_token_indexes] - ) - - packed_key_states = packed_key_states.to(torch.float32) - packed_key_states[packed_text_indexes] = self.k_norm(packed_key_states[packed_text_indexes]) - packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen( - packed_key_states[packed_vae_token_indexes] - ) - - cos, sin = [x[..., : self.head_dim // 2] for x in packed_query_position_embeddings] - packed_query_states = self.rotary_op(packed_query_states.to(cos.dtype).unsqueeze(0), cos, sin).squeeze(0) - packed_key_states = self.rotary_op(packed_key_states.to(cos.dtype).unsqueeze(0), cos, sin).squeeze(0) - - packed_query_states = packed_query_states.to(torch.bfloat16) - packed_key_states = packed_key_states.to(torch.bfloat16) - packed_value_states = packed_value_states.to(torch.bfloat16) - - if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None: - past_key_states = past_key_values.key_cache[self.layer_idx] - past_value_states = past_key_values.value_cache[self.layer_idx] - - seqlens = sum(query_lens) + sum(key_values_lens) - merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_kv_heads, self.head_dim]) - merged_value_states = past_key_states.new_zeros(size=[seqlens, self.num_kv_heads, self.head_dim]) - merged_key_states[packed_query_indexes] = packed_key_states - merged_key_states[packed_key_value_indexes] = past_key_states - merged_value_states[packed_query_indexes] = packed_value_states - merged_value_states[packed_key_value_indexes] = past_value_states - key_values_lens = key_values_lens + query_lens - else: - merged_key_states = packed_key_states - merged_value_states = packed_value_states - key_values_lens = query_lens - - cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0)) - cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0)) - - packed_attn_output = flash_attn_varlen_func( - q=packed_query_states, - k=merged_key_states, - v=merged_value_states, - cu_seqlens_q=cu_seqlens_q.to(torch.int32), - cu_seqlens_k=cu_seqlens_k.to(torch.int32), - max_seqlen_q=max(query_lens).item(), - max_seqlen_k=max(key_values_lens).item(), - causal=is_causal, + return self._forward_und( + packed_query_sequence=packed_query_sequence, + packed_query_position_embeddings=packed_query_position_embeddings, + past_key_values=past_key_values, + is_causal=is_causal, + update_past_key_values=update_past_key_values, ) - packed_attn_output = packed_attn_output.reshape(-1, self.q_size) - if mode == "und": - packed_attn_output, _ = self.o_proj(packed_attn_output) - elif mode == "gen": - text_out, _ = self.o_proj(packed_attn_output[packed_text_indexes]) - vae_out, _ = self.o_proj_moe_gen(packed_attn_output[packed_vae_token_indexes]) - full_output = text_out.new_zeros((packed_attn_output.shape[0], self.hidden_size)) - full_output[packed_text_indexes] = text_out - full_output[packed_vae_token_indexes] = vae_out - packed_attn_output = full_output - - if update_past_key_values: - past_key_values.key_cache[self.layer_idx] = merged_key_states - past_key_values.value_cache[self.layer_idx] = merged_value_states - - return packed_attn_output, past_key_values class Qwen2MoTDecoderLayer(nn.Module): @@ -687,10 +653,7 @@ def forward( packed_query_sequence: torch.Tensor | None = None, query_lens: torch.Tensor = None, packed_query_position_embeddings: torch.Tensor = None, - packed_query_indexes: torch.Tensor = None, past_key_values: NaiveCache | None = None, - key_values_lens: torch.Tensor | None = None, - packed_key_value_indexes: torch.Tensor | None = None, update_past_key_values=True, is_causal=True, mode="und", @@ -717,10 +680,7 @@ def forward( packed_query_sequence=packed_query_sequence, query_lens=query_lens, packed_query_position_embeddings=packed_query_position_embeddings, - packed_query_indexes=packed_query_indexes, past_key_values=past_key_values, - key_values_lens=key_values_lens, - packed_key_value_indexes=packed_key_value_indexes, update_past_key_values=update_past_key_values, is_causal=is_causal, mode=mode, @@ -801,10 +761,7 @@ def forward( packed_query_sequence: torch.Tensor | None = None, query_lens: torch.Tensor | None = None, packed_query_position_ids: torch.Tensor | None = None, - packed_query_indexes: torch.Tensor | None = None, past_key_values: NaiveCache | None = None, - key_values_lens: torch.Tensor | None = None, - packed_key_value_indexes: torch.Tensor | None = None, update_past_key_values=True, is_causal=True, mode="und", @@ -842,15 +799,16 @@ def forward( ) for layer_idx, decoder_layer in enumerate(self.layers): + # TODO (Alex): Remove encoder_hidden_states as a kwarg; currently we keep it + # for compatibility with the current custom CacheDiT adapter, as we need to be + # careful to not break the NaiveCache handling when switching from pattern + # 0 -> 4. packed_query_sequence, past_key_values = decoder_layer( hidden_states=packed_query_sequence, encoder_hidden_states=None, query_lens=query_lens, packed_query_position_embeddings=packed_query_position_embeddings, - packed_query_indexes=packed_query_indexes, past_key_values=past_key_values, - key_values_lens=key_values_lens, - packed_key_value_indexes=packed_key_value_indexes, update_past_key_values=update_past_key_values, is_causal=is_causal, **extra_inputs, @@ -918,10 +876,7 @@ def forward( packed_query_sequence: torch.Tensor | None = None, query_lens: torch.Tensor | None = None, packed_query_position_ids: torch.Tensor | None = None, - packed_query_indexes: torch.Tensor | None = None, past_key_values: NaiveCache | None = None, - key_values_lens: torch.Tensor | None = None, - packed_key_value_indexes: torch.Tensor | None = None, update_past_key_values=True, is_causal=True, mode="und", @@ -934,10 +889,7 @@ def forward( packed_query_sequence=packed_query_sequence, query_lens=query_lens, packed_query_position_ids=packed_query_position_ids, - packed_query_indexes=packed_query_indexes, past_key_values=past_key_values, - key_values_lens=key_values_lens, - packed_key_value_indexes=packed_key_value_indexes, update_past_key_values=update_past_key_values, is_causal=is_causal, mode=mode, @@ -1225,32 +1177,21 @@ def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ packed_text_ids = list() packed_text_position_ids = list() text_token_lens = list() - packed_text_indexes = list() - packed_key_value_indexes = list() - curr = 0 newlens, new_rope = list(), list() for prompt, curr_kvlen, curr_position_id in zip(prompts, curr_kvlens, curr_rope): - packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) - curr += curr_kvlen - text_ids = tokenizer.encode(prompt, add_special_tokens=False) text_ids = [new_token_ids["bos_token_id"]] + text_ids + [new_token_ids["eos_token_id"]] text_token_lens.append(len(text_ids)) packed_text_ids.extend(text_ids) packed_text_position_ids.extend(range(curr_position_id, curr_position_id + len(text_ids))) - packed_text_indexes.extend(range(curr, curr + len(text_ids))) newlens.append(curr_kvlen + len(text_ids)) new_rope.append(curr_position_id + len(text_ids)) - curr += len(text_ids) generation_input = { "text_token_lens": torch.tensor(text_token_lens, dtype=torch.int), "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), "packed_text_position_ids": torch.tensor(packed_text_position_ids, dtype=torch.long), - "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), - "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), - "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), } return generation_input, newlens, new_rope @@ -1261,9 +1202,6 @@ def forward_cache_update_text( packed_text_ids: torch.IntTensor, packed_text_position_ids: torch.LongTensor, text_token_lens: torch.LongTensor, - packed_text_indexes: torch.LongTensor, - packed_key_value_indexes: torch.LongTensor, - key_values_lens: torch.IntTensor, ): extra_inputs = {} if self.use_moe: @@ -1273,10 +1211,7 @@ def forward_cache_update_text( packed_text_ids=packed_text_ids, query_lens=text_token_lens, packed_query_position_ids=packed_text_position_ids, - packed_query_indexes=packed_text_indexes, past_key_values=past_key_values, - packed_key_value_indexes=packed_key_value_indexes, - key_values_lens=key_values_lens, update_past_key_values=True, is_causal=True, **extra_inputs, @@ -1289,20 +1224,14 @@ def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_tok patchified_vae_latent_shapes, packed_vae_position_ids = list(), list() packed_vae_token_indexes = list() packed_text_ids, packed_text_indexes = list(), list() - packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list() - packed_key_value_indexes = list() + packed_seqlens, packed_position_ids = list(), list() - _curr = curr = 0 + _curr = 0 vae_image_tensors = list() newlens, new_rope = list(), list() for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope): - packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) - curr += curr_kvlen - packed_text_ids.append(new_token_ids["start_of_image"]) packed_text_indexes.append(_curr) - packed_indexes.append(curr) - curr += 1 _curr += 1 image_tensor = transforms(image) @@ -1321,14 +1250,10 @@ def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_tok num_img_tokens = w * h packed_vae_token_indexes.extend(range(_curr, _curr + num_img_tokens)) - packed_indexes.extend(range(curr, curr + num_img_tokens)) - curr += num_img_tokens _curr += num_img_tokens packed_text_ids.append(new_token_ids["end_of_image"]) packed_text_indexes.append(_curr) - packed_indexes.append(curr) - curr += 1 _curr += 1 packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2)) @@ -1352,9 +1277,6 @@ def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_tok "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long), "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int), - "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long), - "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), - "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), } return generation_input, newlens, new_rope @@ -1372,9 +1294,6 @@ def forward_cache_update_vae( packed_text_indexes: torch.LongTensor, packed_position_ids: torch.LongTensor, packed_seqlens: torch.IntTensor, - packed_indexes: torch.LongTensor, - key_values_lens: torch.IntTensor, - packed_key_value_indexes: torch.Tensor, ): padded_latent = vae_model.encode(padded_images) @@ -1413,10 +1332,7 @@ def forward_cache_update_vae( packed_query_sequence=packed_sequence, query_lens=packed_seqlens, packed_query_position_ids=packed_position_ids, - packed_query_indexes=packed_indexes, past_key_values=past_key_values, - key_values_lens=key_values_lens, - packed_key_value_indexes=packed_key_value_indexes, update_past_key_values=True, is_causal=False, **extra_inputs, @@ -1429,19 +1345,13 @@ def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_tok packed_vit_token_indexes = list() vit_token_seqlens, packed_vit_tokens, packed_vit_position_ids = list(), list(), list() packed_text_ids, packed_text_indexes = list(), list() - packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list() - packed_key_value_indexes = list() + packed_seqlens, packed_position_ids = list(), list() - _curr = curr = 0 + _curr = 0 newlens, new_rope = list(), list() for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope): - packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) - curr += curr_kvlen - packed_text_ids.append(new_token_ids["start_of_image"]) packed_text_indexes.append(_curr) - packed_indexes.append(curr) - curr += 1 _curr += 1 image_tensor = transforms(image) @@ -1457,14 +1367,10 @@ def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_tok packed_vit_position_ids.append(vit_position_ids) vit_token_seqlens.append(num_img_tokens) packed_vit_token_indexes.extend(range(_curr, _curr + num_img_tokens)) - packed_indexes.extend(range(curr, curr + num_img_tokens)) - curr += num_img_tokens _curr += num_img_tokens packed_text_ids.append(new_token_ids["end_of_image"]) packed_text_indexes.append(_curr) - packed_indexes.append(curr) - curr += 1 _curr += 1 packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2)) @@ -1481,9 +1387,6 @@ def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_tok "packed_vit_token_indexes": torch.tensor(packed_vit_token_indexes, dtype=torch.long), "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long), "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int), - "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long), - "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), - "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), } return generation_input, newlens, new_rope @@ -1499,9 +1402,6 @@ def forward_cache_update_vit( vit_token_seqlens: torch.IntTensor, packed_position_ids: torch.LongTensor, packed_seqlens: torch.IntTensor, - packed_indexes: torch.LongTensor, - packed_key_value_indexes: torch.LongTensor, - key_values_lens: torch.IntTensor, ): packed_text_embedding = self.language_model.forward( packed_text_ids=packed_text_ids, @@ -1534,10 +1434,7 @@ def forward_cache_update_vit( packed_query_sequence=packed_sequence, query_lens=packed_seqlens, packed_query_position_ids=packed_position_ids, - packed_query_indexes=packed_indexes, past_key_values=past_key_values, - packed_key_value_indexes=packed_key_value_indexes, - key_values_lens=key_values_lens, update_past_key_values=True, is_causal=False, **extra_inputs, @@ -1549,19 +1446,12 @@ def forward_cache_update_vit( def prepare_input(self, curr_kvlens, curr_rope, image_sizes, new_token_ids=None): packed_text_ids, packed_text_indexes = list(), list() packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = list(), list(), list() - packed_position_ids, packed_seqlens, packed_indexes = list(), list(), list() - packed_key_value_indexes = list() + packed_position_ids, packed_seqlens = list(), list() - query_curr = curr = 0 + query_curr = 0 for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope): - packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) - curr += curr_kvlen - packed_text_ids.append(new_token_ids["start_of_image"]) packed_text_indexes.append(query_curr) - - packed_indexes.append(curr) - curr += 1 query_curr += 1 vae_position_ids = self.get_flattened_position_ids( @@ -1575,16 +1465,10 @@ def prepare_input(self, curr_kvlens, curr_rope, image_sizes, new_token_ids=None) packed_init_noises.append(torch.randn(num_image_tokens, self.latent_channel * self.latent_patch_size**2)) packed_vae_token_indexes.extend(range(query_curr, query_curr + num_image_tokens)) packed_seqlens.append(num_image_tokens + 2) - - packed_indexes.extend(range(curr, curr + num_image_tokens)) - curr += num_image_tokens query_curr += num_image_tokens packed_text_ids.append(new_token_ids["end_of_image"]) packed_text_indexes.append(query_curr) - - packed_indexes.append(curr) - curr += 1 query_curr += 1 packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2)) @@ -1598,9 +1482,6 @@ def prepare_input(self, curr_kvlens, curr_rope, image_sizes, new_token_ids=None) "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long), "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int), "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long), - "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), - "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long), - "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), } return generation_input @@ -1609,34 +1490,15 @@ def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids) return self.prepare_input(curr_kvlens, curr_rope, image_sizes, new_token_ids) def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes): - packed_position_ids, packed_indexes, packed_key_value_indexes = list(), list(), list() + packed_position_ids = list() - query_curr = curr = 0 for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope): - packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) - curr += curr_kvlen - - packed_indexes.append(curr) - curr += 1 - query_curr += 1 - h, w = H // self.latent_downsample, W // self.latent_downsample num_image_tokens = h * w - packed_indexes.extend(range(curr, curr + num_image_tokens)) - curr += num_image_tokens - query_curr += num_image_tokens - - packed_indexes.append(curr) - curr += 1 - query_curr += 1 - packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2)) generation_input = { "cfg_packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long), - "cfg_key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), - "cfg_packed_query_indexes": torch.tensor(packed_indexes, dtype=torch.long), - "cfg_packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), } return generation_input @@ -1663,21 +1525,15 @@ def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids): Ported from the original BAGEL ``Bagel.prepare_start_tokens``. """ - packed_start_tokens, packed_key_value_indexes = list(), list() + packed_start_tokens = list() packed_query_position_ids = list() - curr = 0 for curr_kvlen, curr_position_id in zip(curr_kvlens, curr_rope): - packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) packed_start_tokens.append(new_token_ids["bos_token_id"]) packed_query_position_ids.append(curr_position_id) - curr += curr_kvlen - generation_input = { "packed_start_tokens": torch.tensor(packed_start_tokens, dtype=torch.long), "packed_query_position_ids": torch.tensor(packed_query_position_ids, dtype=torch.long), - "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), - "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), } return generation_input @@ -1685,8 +1541,6 @@ def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids): def generate_text( self, past_key_values: NaiveCache, - packed_key_value_indexes: torch.LongTensor, - key_values_lens: torch.IntTensor, packed_start_tokens: torch.LongTensor, packed_query_position_ids: torch.LongTensor, max_length: int, @@ -1705,26 +1559,12 @@ def generate_text( while step < max_length: generated_sequence.append(curr_tokens) query_lens = torch.ones_like(curr_tokens) - packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange( - 0, - len(key_values_lens), - device=key_values_lens.device, - dtype=key_values_lens.dtype, - ) - - uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0)) - for i in range(len(uppacked)): - uppacked[i] += i - packed_key_value_indexes = torch.cat(uppacked, dim=0) output = self.language_model( packed_text_ids=curr_tokens, query_lens=query_lens, packed_query_position_ids=packed_query_position_ids, - packed_query_indexes=packed_query_indexes, past_key_values=past_key_values, - key_values_lens=key_values_lens, - packed_key_value_indexes=packed_key_value_indexes, update_past_key_values=True, is_causal=True, mode="und", @@ -1739,13 +1579,6 @@ def generate_text( else: curr_tokens = torch.argmax(pred_logits, dim=-1) - uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0)) - for i in range(len(uppacked)): - uppacked[i] = torch.cat( - [uppacked[i], torch.tensor([uppacked[i][-1] + 1], device=uppacked[i].device)], dim=0 - ) - packed_key_value_indexes = torch.cat(uppacked, dim=0) - key_values_lens = key_values_lens + 1 packed_query_position_ids = packed_query_position_ids + 1 step += 1 @@ -1764,10 +1597,7 @@ def generate_image( packed_vae_token_indexes: torch.LongTensor, packed_seqlens: torch.IntTensor, packed_position_ids: torch.LongTensor, - packed_indexes: torch.LongTensor, past_key_values: NaiveCache, - key_values_lens: torch.IntTensor, - packed_key_value_indexes: torch.LongTensor, num_timesteps: int = 24, timestep_shift: float = 1.0, cfg_renorm_min: float = 0.0, @@ -1775,18 +1605,12 @@ def generate_image( cfg_interval: tuple[float, float] = [0, 1], # cfg_text cfg_text_scale: float = 1.0, - cfg_text_packed_query_indexes: torch.LongTensor | None = None, cfg_text_packed_position_ids: torch.LongTensor | None = None, cfg_text_past_key_values: NaiveCache | None = None, - cfg_text_key_values_lens: torch.IntTensor | None = None, - cfg_text_packed_key_value_indexes: torch.LongTensor | None = None, # cfg_img cfg_img_scale: float = 1.0, - cfg_img_packed_query_indexes: torch.LongTensor | None = None, cfg_img_packed_position_ids: torch.LongTensor | None = None, cfg_img_past_key_values: NaiveCache | None = None, - cfg_img_key_values_lens: torch.IntTensor | None = None, - cfg_img_packed_key_value_indexes: torch.LongTensor | None = None, return_trajectory_latents: bool = False, scheduler: object | None = None, scheduler_kwargs: dict | None = None, @@ -1823,25 +1647,16 @@ def generate_image( packed_vae_token_indexes=packed_vae_token_indexes, packed_seqlens=packed_seqlens, packed_position_ids=packed_position_ids, - packed_indexes=packed_indexes, past_key_values=past_key_values, - key_values_lens=key_values_lens, - packed_key_value_indexes=packed_key_value_indexes, cfg_renorm_min=cfg_renorm_min, cfg_renorm_type=cfg_renorm_type, cfg_interval=cfg_interval, cfg_text_scale=cfg_text_scale, - cfg_text_packed_query_indexes=cfg_text_packed_query_indexes, cfg_text_packed_position_ids=cfg_text_packed_position_ids, cfg_text_past_key_values=cfg_text_past_key_values, - cfg_text_key_values_lens=cfg_text_key_values_lens, - cfg_text_packed_key_value_indexes=cfg_text_packed_key_value_indexes, cfg_img_scale=cfg_img_scale, - cfg_img_packed_query_indexes=cfg_img_packed_query_indexes, cfg_img_packed_position_ids=cfg_img_packed_position_ids, cfg_img_past_key_values=cfg_img_past_key_values, - cfg_img_key_values_lens=cfg_img_key_values_lens, - cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes, return_trajectory_latents=return_trajectory_latents, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, @@ -1870,31 +1685,22 @@ def generate_image( v_t = self.forward_single_branch( **common, - packed_indexes=packed_indexes, packed_position_ids=packed_position_ids, - key_values_lens=key_values_lens, past_key_values=past_key_values, - packed_key_value_indexes=packed_key_value_indexes, ) if cfg_text_scale_ > 1.0: cfg_text_v_t = self.forward_single_branch( **common, - packed_indexes=cfg_text_packed_query_indexes, packed_position_ids=cfg_text_packed_position_ids, - key_values_lens=cfg_text_key_values_lens, past_key_values=cfg_text_past_key_values, - packed_key_value_indexes=cfg_text_packed_key_value_indexes, ) cfg_img_v_t = None if cfg_img_scale_ > 1.0: cfg_img_v_t = self.forward_single_branch( **common, - packed_indexes=cfg_img_packed_query_indexes, packed_position_ids=cfg_img_packed_position_ids, - key_values_lens=cfg_img_key_values_lens, past_key_values=cfg_img_past_key_values, - packed_key_value_indexes=cfg_img_packed_key_value_indexes, ) v_t = self._combine_cfg( v_t, @@ -1933,12 +1739,9 @@ def generate_image( packed_vae_position_ids=packed_vae_position_ids, packed_text_ids=packed_text_ids, packed_text_indexes=packed_text_indexes, - packed_indexes=packed_indexes, packed_position_ids=packed_position_ids, packed_seqlens=packed_seqlens, - key_values_lens=key_values_lens, past_key_values=past_key_values, - packed_key_value_indexes=packed_key_value_indexes, ) if scheduler is not None: out = scheduler.step(v_t.to(x_t.device), timesteps[i], x_t, dts[i], **_sched_kw) @@ -1962,46 +1765,25 @@ def generate_image( seq_len = int(packed_seqlens.sum()) # Branch 0: main (gen_context), always present - branches_qi = [packed_indexes] - branches_kvi = [packed_key_value_indexes] - branches_kvl = [key_values_lens] branches_pid = [packed_position_ids] branches_cache = [past_key_values] # Branch 1: cfg_text (unconditional text), always present when use_cfg_text - branches_qi.append(cfg_text_packed_query_indexes) - branches_kvi.append(cfg_text_packed_key_value_indexes) - branches_kvl.append(cfg_text_key_values_lens) branches_pid.append(cfg_text_packed_position_ids) branches_cache.append(cfg_text_past_key_values) # Branch 2: cfg_img (text-only, no image), optional if use_cfg_img: - branches_qi.append(cfg_img_packed_query_indexes) - branches_kvi.append(cfg_img_packed_key_value_indexes) - branches_kvl.append(cfg_img_key_values_lens) branches_pid.append(cfg_img_packed_position_ids) branches_cache.append(cfg_img_past_key_values) num_branches = len(branches_cache) - # Compute per-branch offsets in the merged KV+Q attention tensor - merged_offsets = [0] - for b_idx in range(num_branches): - merged_offsets.append(merged_offsets[-1] + int(branches_kvl[b_idx].sum()) + seq_len) - cfg_batched = { "num_branches": num_branches, "seq_len": seq_len, "batched_query_lens": packed_seqlens.repeat(num_branches), "batched_position_ids": torch.cat(branches_pid), - "batched_kv_lens": torch.cat(branches_kvl), - "batched_query_indexes": torch.cat( - [qi + merged_offsets[b_idx] for b_idx, qi in enumerate(branches_qi)] - ), - "batched_kv_indexes": torch.cat( - [kvi + merged_offsets[b_idx] for b_idx, kvi in enumerate(branches_kvi)] - ), "batched_text_indexes": torch.cat( [packed_text_indexes + b_idx * seq_len for b_idx in range(num_branches)] ), @@ -2030,11 +1812,8 @@ def generate_image( packed_text_ids=packed_text_ids, packed_text_indexes=packed_text_indexes, packed_position_ids=packed_position_ids, - packed_indexes=packed_indexes, packed_seqlens=packed_seqlens, - key_values_lens=key_values_lens, past_key_values=past_key_values, - packed_key_value_indexes=packed_key_value_indexes, cfg_renorm_min=cfg_renorm_min, cfg_renorm_type=cfg_renorm_type, cfg_text_scale=cfg_text_scale_, @@ -2067,25 +1846,16 @@ def _generate_image_parallel( packed_vae_token_indexes: torch.LongTensor, packed_seqlens: torch.IntTensor, packed_position_ids: torch.LongTensor, - packed_indexes: torch.LongTensor, past_key_values: NaiveCache, - key_values_lens: torch.IntTensor, - packed_key_value_indexes: torch.LongTensor, cfg_renorm_min: float, cfg_renorm_type: str, cfg_interval: tuple[float, float], cfg_text_scale: float, - cfg_text_packed_query_indexes: torch.LongTensor | None, cfg_text_packed_position_ids: torch.LongTensor | None, cfg_text_past_key_values: NaiveCache | None, - cfg_text_key_values_lens: torch.IntTensor | None, - cfg_text_packed_key_value_indexes: torch.LongTensor | None, cfg_img_scale: float, - cfg_img_packed_query_indexes: torch.LongTensor | None, cfg_img_packed_position_ids: torch.LongTensor | None, cfg_img_past_key_values: NaiveCache | None, - cfg_img_key_values_lens: torch.IntTensor | None, - cfg_img_packed_key_value_indexes: torch.LongTensor | None, return_trajectory_latents: bool = False, scheduler: object | None = None, scheduler_kwargs: dict | None = None, @@ -2124,24 +1894,15 @@ def _generate_image_parallel( if cfg_rank == 0: # Gen branch: use main inputs directly branch_position_ids = packed_position_ids - branch_indexes = packed_indexes branch_past_key_values = past_key_values - branch_key_values_lens = key_values_lens - branch_key_value_indexes = packed_key_value_indexes elif cfg_rank == 1: # Text CFG branch branch_position_ids = cfg_text_packed_position_ids - branch_indexes = cfg_text_packed_query_indexes branch_past_key_values = cfg_text_past_key_values - branch_key_values_lens = cfg_text_key_values_lens - branch_key_value_indexes = cfg_text_packed_key_value_indexes elif cfg_rank == 2: # Image CFG branch branch_position_ids = cfg_img_packed_position_ids - branch_indexes = cfg_img_packed_query_indexes branch_past_key_values = cfg_img_past_key_values - branch_key_values_lens = cfg_img_key_values_lens - branch_key_value_indexes = cfg_img_packed_key_value_indexes else: raise RuntimeError(f"Unexpected cfg_rank={cfg_rank} for Bagel 3-branch CFG parallel") @@ -2168,12 +1929,9 @@ def _generate_image_parallel( packed_vae_position_ids=packed_vae_position_ids, packed_text_ids=packed_text_ids, packed_text_indexes=packed_text_indexes, - packed_indexes=branch_indexes, packed_position_ids=branch_position_ids, packed_seqlens=packed_seqlens, - key_values_lens=branch_key_values_lens, past_key_values=branch_past_key_values, - packed_key_value_indexes=branch_key_value_indexes, ) gathered = cfg_group.all_gather(local_v_t, separate_tensors=True) @@ -2195,12 +1953,9 @@ def _generate_image_parallel( packed_vae_position_ids=packed_vae_position_ids, packed_text_ids=packed_text_ids, packed_text_indexes=packed_text_indexes, - packed_indexes=packed_indexes, packed_position_ids=packed_position_ids, packed_seqlens=packed_seqlens, - key_values_lens=key_values_lens, past_key_values=past_key_values, - packed_key_value_indexes=packed_key_value_indexes, ) if scheduler is not None: @@ -2278,12 +2033,9 @@ def forward_single_branch( packed_vae_position_ids: torch.LongTensor, packed_text_ids: torch.LongTensor, packed_text_indexes: torch.LongTensor, - packed_indexes: torch.LongTensor, packed_position_ids: torch.LongTensor, packed_seqlens: torch.IntTensor, - key_values_lens: torch.IntTensor, past_key_values: NaiveCache, - packed_key_value_indexes: torch.LongTensor, ) -> torch.Tensor: """Run a single-branch forward pass (no CFG batching). @@ -2327,24 +2079,6 @@ def forward_single_branch( x_t_emb = x_t_emb.to(packed_sequence.dtype) packed_sequence[local_vae_indexes] = x_t_emb - # Build local packed_indexes for KV cache merging. - # In the denoising loop packed_indexes is always contiguous - # (arange(kv_len, kv_len + total)), so we can safely build - # the local slice from scratch. - local_total = int(local_seqlens.sum()) - kv_len = int(key_values_lens.sum()) - original_total = int(packed_seqlens.sum()) - assert torch.equal( - packed_indexes, - torch.arange(kv_len, kv_len + original_total, device=packed_indexes.device, dtype=packed_indexes.dtype), - ), "packed_indexes must be contiguous for SP; non-contiguous layout not supported" - local_packed_indexes = torch.arange( - kv_len, - kv_len + local_total, - device=packed_indexes.device, - dtype=packed_indexes.dtype, - ) - extra_inputs = {} if self.use_moe: extra_inputs["mode"] = "gen" @@ -2355,10 +2089,7 @@ def forward_single_branch( packed_query_sequence=packed_sequence, query_lens=local_seqlens, packed_query_position_ids=local_position_ids, - packed_query_indexes=local_packed_indexes, past_key_values=past_key_values, - key_values_lens=key_values_lens, - packed_key_value_indexes=packed_key_value_indexes, update_past_key_values=False, is_causal=False, **extra_inputs, @@ -2394,10 +2125,7 @@ def forward_single_branch( packed_query_sequence=packed_sequence, query_lens=packed_seqlens, packed_query_position_ids=packed_position_ids, - packed_query_indexes=packed_indexes, past_key_values=past_key_values, - key_values_lens=key_values_lens, - packed_key_value_indexes=packed_key_value_indexes, update_past_key_values=False, is_causal=False, **extra_inputs, @@ -2414,12 +2142,9 @@ def forward( packed_vae_position_ids: torch.LongTensor, packed_text_ids: torch.LongTensor, packed_text_indexes: torch.LongTensor, - packed_indexes: torch.LongTensor, packed_position_ids: torch.LongTensor, packed_seqlens: torch.IntTensor, - key_values_lens: torch.IntTensor, past_key_values: NaiveCache, - packed_key_value_indexes: torch.LongTensor, cfg_renorm_min: float = 0.0, cfg_renorm_type: str = "global", cfg_text_scale: float = 1.0, @@ -2465,10 +2190,7 @@ def forward( packed_query_sequence=batched_sequence, query_lens=cfg_batched["batched_query_lens"], packed_query_position_ids=cfg_batched["batched_position_ids"], - packed_query_indexes=cfg_batched["batched_query_indexes"], past_key_values=cfg_batched["merged_cache"], - key_values_lens=cfg_batched["batched_kv_lens"], - packed_key_value_indexes=cfg_batched["batched_kv_indexes"], update_past_key_values=False, is_causal=False, **extra_inputs, @@ -2501,10 +2223,7 @@ def forward( packed_query_sequence=packed_sequence, query_lens=packed_seqlens, packed_query_position_ids=packed_position_ids, - packed_query_indexes=packed_indexes, past_key_values=past_key_values, - key_values_lens=key_values_lens, - packed_key_value_indexes=packed_key_value_indexes, update_past_key_values=False, is_causal=False, **extra_inputs, diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index 2363a690a07..a62d2a75ad4 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -793,13 +793,7 @@ def vae_transforms(img): cfg_renorm_type=gen_params.cfg_renorm_type, **generation_input, cfg_text_packed_position_ids=generation_input_cfg_text["cfg_packed_position_ids"], - cfg_text_packed_query_indexes=generation_input_cfg_text["cfg_packed_query_indexes"], - cfg_text_key_values_lens=generation_input_cfg_text["cfg_key_values_lens"], - cfg_text_packed_key_value_indexes=generation_input_cfg_text["cfg_packed_key_value_indexes"], cfg_img_packed_position_ids=generation_input_cfg_img["cfg_packed_position_ids"], - cfg_img_packed_query_indexes=generation_input_cfg_img["cfg_packed_query_indexes"], - cfg_img_key_values_lens=generation_input_cfg_img["cfg_key_values_lens"], - cfg_img_packed_key_value_indexes=generation_input_cfg_img["cfg_packed_key_value_indexes"], return_trajectory_latents=req.sampling_params.return_trajectory_latents, scheduler=self.scheduler, scheduler_kwargs=self.scheduler_kwargs,