From 6a9a21a91cda265333a18a1c1765bd7eeed6f3fb Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 2 Jun 2026 05:27:19 +0000 Subject: [PATCH 01/12] bagel fix wip Signed-off-by: Alex Brooks minor Signed-off-by: Alex Brooks fix ref Signed-off-by: Alex Brooks --- .../models/bagel/bagel_transformer.py | 93 +++++++++++++++++-- 1 file changed, 83 insertions(+), 10 deletions(-) diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index 4231d4cc638..447dc7d38d0 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -306,6 +306,10 @@ class NaiveCache: def __init__(self, num_layers): self.key_cache = {k: None for k in range(num_layers)} self.value_cache = {k: None for k in range(num_layers)} + # Track kv_lens; we need this because we pack the forward passes + # for CFG into a single forward call and the kv length may be different, + # e.g., due to 0 kvs for text_cfg path and nonzero for others + self.key_values_lens: list[int] | None = None @property def num_layers(self): @@ -431,6 +435,7 @@ def _is_sp_active(self) -> bool: def _forward_gen( self, packed_query_sequence: torch.Tensor, + query_lens: torch.Tensor, packed_query_position_embeddings: torch.Tensor, past_key_values: NaiveCache | None, packed_vae_token_indexes: torch.Tensor, @@ -455,6 +460,7 @@ def _forward_gen( Currently we shouldn't need it in the model, and it would be ideal to handle packing/batching etc in a more model agnostic way. """ + cache_k = cache_v = None packed_query_sequence = packed_query_sequence.to(torch.bfloat16) packed_text_query_sequence = packed_query_sequence[packed_text_indexes] packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes] @@ -505,7 +511,9 @@ def _forward_gen( if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None: cache_k = past_key_values.key_cache[self.layer_idx] cache_v = past_key_values.value_cache[self.layer_idx] - ctx_k = torch.cat([cache_k, text_k], dim=0) + ctx_k = torch.cat( + [cache_k, text_k], dim=0 + ) # we are catting a [20,4,128] to [6, 4, 128] -> [26, 4, 128] but not sure why we have 20 ctx_v = torch.cat([cache_v, text_v], dim=0) else: ctx_k = text_k @@ -527,15 +535,63 @@ def _forward_gen( ), ) else: - q = torch.cat([text_q, vae_q], dim=0).unsqueeze(0) - k = torch.cat([ctx_k, vae_k], dim=0).unsqueeze(0) - v = torch.cat([ctx_v, vae_v], dim=0).unsqueeze(0) - attn_out = self.attn_noncausal(q, k, v) + num_branches = len(query_lens) + text_per_branch = text_q.shape[0] // num_branches + vae_per_branch = vae_q.shape[0] // num_branches + + text_q_parts = text_q.split([text_per_branch] * num_branches) + vae_q_parts = vae_q.split([vae_per_branch] * num_branches) + text_k_parts = text_k.split([text_per_branch] * num_branches) + vae_k_parts = vae_k.split([vae_per_branch] * num_branches) + text_v_parts = text_v.split([text_per_branch] * num_branches) + vae_v_parts = vae_v.split([vae_per_branch] * num_branches) + + q_4d = torch.stack([torch.cat([t, v]) for t, v in zip(text_q_parts, vae_q_parts)]) # [3, 4098, 28, 128] + + if cache_k is not None: + kv_lens = getattr(past_key_values, "key_values_lens", None) + if kv_lens is None: + per_branch = cache_k.shape[0] // num_branches + kv_lens = [per_branch] * num_branches + nonzero = [kv_len for kv_len in kv_lens if kv_len > 0] + ck_parts = list(cache_k.split(nonzero)) if nonzero else [] + cv_parts = list(cache_v.split(nonzero)) if nonzero else [] + ci = 0 + k_branches, v_branches = [], [] + for i in range(num_branches): + kp, vp = [], [] + if kv_lens[i] > 0: + kp.append( + ck_parts[ci] + ) # This is causing issues because it's [10, 0, 10], but why do we have uneven kv cache? + vp.append(cv_parts[ci]) + ci += 1 + kp += [text_k_parts[i], vae_k_parts[i]] + vp += [text_v_parts[i], vae_v_parts[i]] + k_branches.append(torch.cat(kp)) + v_branches.append(torch.cat(vp)) + + max_k = max(b.shape[0] for b in k_branches) + k_4d = cache_k.new_zeros(num_branches, max_k, self.num_kv_heads, self.head_dim) + v_4d = cache_v.new_zeros(num_branches, max_k, self.num_kv_heads, self.head_dim) + mask = torch.zeros(num_branches, max_k, dtype=torch.bool, device=cache_k.device) + for i in range(num_branches): + klen = k_branches[i].shape[0] + pad = max_k - klen + k_4d[i, pad:] = k_branches[i] + v_4d[i, pad:] = v_branches[i] + mask[i, pad:] = True + metadata = DiffusionAttentionMetadata(attn_mask=mask) if torch.any(~mask) else None + else: + k_4d = torch.stack([torch.cat([t, v]) for t, v in zip(text_k_parts, vae_k_parts)]) + v_4d = torch.stack([torch.cat([t, v]) for t, v in zip(text_v_parts, vae_v_parts)]) + metadata = None - text_len = text_q.shape[0] - attn_out = attn_out.squeeze(0) - text_attn = attn_out[:text_len].reshape(text_len, self.q_size) - vae_attn = attn_out[text_len:].reshape(-1, self.q_size) + attn_out = self.attn_noncausal(q_4d, k_4d, v_4d, metadata) + q_per_branch = int(query_lens[0]) + attn_out = attn_out.reshape(num_branches, q_per_branch, self.q_size) + text_attn = attn_out[:, :text_per_branch].reshape(-1, self.q_size) + vae_attn = attn_out[:, text_per_branch:].reshape(-1, self.q_size) # Apply output projections text_out, _ = self.o_proj(text_attn) @@ -675,6 +731,7 @@ def forward( raise ValueError("Generation model for Bagel requires non-causal attention") return self._forward_gen( packed_query_sequence=packed_query_sequence, + query_lens=query_lens, packed_query_position_embeddings=packed_query_position_embeddings, past_key_values=past_key_values, packed_vae_token_indexes=packed_vae_token_indexes, @@ -874,7 +931,7 @@ def forward( if mode == "gen": assert packed_vae_token_indexes is not None assert packed_text_indexes is not None - extra_inputs.update( + extra_inputs.update( # we have 3 <4096> tokens packed_vae_token_indexes=packed_vae_token_indexes, packed_text_indexes=packed_text_indexes, ) @@ -1584,6 +1641,22 @@ def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes): return generation_input + @staticmethod + def _merge_naive_caches(caches: list) -> NaiveCache: + """Merge multiple NaiveCache objects by concatenating KV tensors per layer.""" + if not caches: + return NaiveCache(0) + + num_layers = len(caches[0].key_cache) + merged = NaiveCache(num_layers) + for layer_idx in range(num_layers): + key_parts = [c.key_cache[layer_idx] for c in caches if c.key_cache[layer_idx] is not None] + val_parts = [c.value_cache[layer_idx] for c in caches if c.value_cache[layer_idx] is not None] + merged.key_cache[layer_idx] = torch.cat(key_parts, dim=0) if key_parts else None + merged.value_cache[layer_idx] = torch.cat(val_parts, dim=0) if val_parts else None + merged.key_values_lens = [c.key_cache[0].shape[0] if c.key_cache[0] is not None else 0 for c in caches] + return merged + def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids): """Prepare start tokens for autoregressive text generation. From 7a3fb269929331c6f20716ed78c799271e66ef38 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 2 Jun 2026 06:06:22 +0000 Subject: [PATCH 02/12] revert lance changes (passing bagel tests) Signed-off-by: Alex Brooks --- .../diffusion/models/bagel/bagel_transformer.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index 447dc7d38d0..0e0f7598063 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -1766,11 +1766,13 @@ def generate_image( frame_condition_token_indexes = frame_condition_token_indexes.to(x_t.device).long() pinned_x_t = x_t[frame_condition_token_indexes].clone() - # Use num_timesteps + 1 sample points so we get `num_timesteps` denoise - # steps after dropping the terminal t=0 (which has no dt). Upstream - # Lance / BAGEL both use this convention; without the +1 we silently - # run one fewer denoise iteration than the user asked for. - timesteps = torch.linspace(1, 0, num_timesteps + 1, device=x_t.device) + # TODO: Re-enable with new reference pixels in Bagel tests + # # Use num_timesteps + 1 sample points so we get `num_timesteps` denoise + # # steps after dropping the terminal t=0 (which has no dt). Upstream + # # Lance / BAGEL both use this convention; without the +1 we silently + # # run one fewer denoise iteration than the user asked for. + # timesteps = torch.linspace(1, 0, num_timesteps + 1, device=x_t.device) + timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device) timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps) dts = timesteps[:-1] - timesteps[1:] timesteps = timesteps[:-1] From cd3935ad97c9d11f6005ac4669b874594ebdb817 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 2 Jun 2026 06:47:28 +0000 Subject: [PATCH 03/12] pop kwarg to fix lance tests Signed-off-by: Alex Brooks --- vllm_omni/diffusion/models/bagel/pipeline_bagel.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index 56a5b0dcc40..f33b806f859 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -553,6 +553,8 @@ def vae_transforms(img): for k, v in gen_input_img.items(): if torch.is_tensor(v): gen_input_img[k] = v.to(self.device) + for k in ("packed_indexes", "packed_key_value_indexes", "key_values_lens"): + gen_input_img.pop(k, None) with torch.autocast( device_type=self.device.type, enabled=self.device.type != "cpu", From de884d51f017f2c096793dfbf199812607c60e4f Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 2 Jun 2026 06:52:07 +0000 Subject: [PATCH 04/12] remove merge Signed-off-by: Alex Brooks --- .../diffusion/models/bagel/bagel_transformer.py | 17 +---------------- .../diffusion/models/bagel/pipeline_bagel.py | 2 +- 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index 0e0f7598063..7dbcbffcd94 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -1516,6 +1516,7 @@ def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_tok newlens.append(curr_kvlen + num_img_tokens + 2) new_rope.append(curr_position_id + 1) + # TODO - 棄用 (deprecated) kwargs should just be removed here so we do not need to pop them later generation_input = { "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), @@ -1641,22 +1642,6 @@ def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes): return generation_input - @staticmethod - def _merge_naive_caches(caches: list) -> NaiveCache: - """Merge multiple NaiveCache objects by concatenating KV tensors per layer.""" - if not caches: - return NaiveCache(0) - - num_layers = len(caches[0].key_cache) - merged = NaiveCache(num_layers) - for layer_idx in range(num_layers): - key_parts = [c.key_cache[layer_idx] for c in caches if c.key_cache[layer_idx] is not None] - val_parts = [c.value_cache[layer_idx] for c in caches if c.value_cache[layer_idx] is not None] - merged.key_cache[layer_idx] = torch.cat(key_parts, dim=0) if key_parts else None - merged.value_cache[layer_idx] = torch.cat(val_parts, dim=0) if val_parts else None - merged.key_values_lens = [c.key_cache[0].shape[0] if c.key_cache[0] is not None else 0 for c in caches] - return merged - def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids): """Prepare start tokens for autoregressive text generation. diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index f33b806f859..9f950970ce7 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -606,7 +606,7 @@ def vae_transforms(img): # cfg_text_context: update with negative prompt (no text condition). # When empty, keep cfg_text_context as-is (kv_lens=0) to match - # original BAGEL; _merge_naive_caches handles None KV entries. + # original BAGEL. neg_prompt = extra_args.get("negative_prompt", "") if neg_prompt: neg_input, neg_newlens, neg_rope = self.bagel.prepare_prompts( From 9804e8d7e41b42478c9d9d5a1d96ab0d4dc9eda7 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 2 Jun 2026 16:27:59 +0000 Subject: [PATCH 05/12] more consolidation and refactoring Signed-off-by: Alex Brooks --- tests/diffusion/utils/__init__.py | 0 tests/diffusion/utils/test_kv_utils.py | 75 ++++++++++ .../models/bagel/bagel_transformer.py | 133 +++++++++++------- vllm_omni/diffusion/utils/kv_utils.py | 39 +++++ 4 files changed, 200 insertions(+), 47 deletions(-) create mode 100644 tests/diffusion/utils/__init__.py create mode 100644 tests/diffusion/utils/test_kv_utils.py create mode 100644 vllm_omni/diffusion/utils/kv_utils.py diff --git a/tests/diffusion/utils/__init__.py b/tests/diffusion/utils/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/diffusion/utils/test_kv_utils.py b/tests/diffusion/utils/test_kv_utils.py new file mode 100644 index 00000000000..6b523b92033 --- /dev/null +++ b/tests/diffusion/utils/test_kv_utils.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for kv utils.""" + +import pytest +import torch + +from vllm_omni.diffusion.utils.kv_utils import left_pad_stack + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def test_uniform_lengths(): + """Ensure uniform 3D tensors correctly stack to 4D and have no mask.""" + tensors = [torch.randn(10, 4, 128) for _ in range(3)] + stacked, mask = left_pad_stack(tensors) + assert stacked.shape == (3, 10, 4, 128) + assert mask is None + for i in range(3): + assert torch.equal(stacked[i], tensors[i]) + + +def test_variable_lengths(): + """Ensure variable 3D tensors correctly stack to 4D with a mask.""" + t1 = torch.ones(5, 2, 4) + t2 = torch.ones(8, 2, 4) + t3 = torch.ones(3, 2, 4) + stacked, mask = left_pad_stack([t1, t2, t3]) + + assert stacked.shape == (3, 8, 2, 4) + assert mask is not None + assert mask.shape == (3, 8) + # Ensure summing over dim 1 gives our seq lens back + assert mask.sum(dim=1).tolist() == [5, 8, 3] + + +def test_single_tensor_is_4d(): + """Ensure a single 3D tensor is expanded to 4D.""" + t = torch.randn(7, 4, 128) + stacked, mask = left_pad_stack([t]) + assert stacked.shape == (1, 7, 4, 128) + assert mask is None + assert torch.equal(stacked[0], t) + + +def test_preserves_device_and_dtype(): + """Ensure device/dtype is preserved.""" + t1 = torch.randn(3, 2, dtype=torch.bfloat16) + t2 = torch.randn(5, 2, dtype=torch.bfloat16) + stacked, mask = left_pad_stack([t1, t2]) + assert stacked.dtype == torch.bfloat16 + assert stacked.device == t1.device + assert mask.device == t1.device + + +def test_mismatched_trailing_shapes_raises(): + """Ensure that mismatched dims outside of 0 explodes.""" + t1 = torch.randn(5, 4, 128) + t2 = torch.randn(5, 8, 128) + with pytest.raises(ValueError): + left_pad_stack([t1, t2]) + + +def test_mismatched_ndim_raises(): + """Ensure that mismatched ndims explodes.""" + t1 = torch.randn(5, 4) + t2 = torch.randn(5, 4, 128) + with pytest.raises(ValueError): + left_pad_stack([t1, t2]) + + +def test_empty_list_raises(): + """Ensure that tensors must be nonempty.""" + with pytest.raises(ValueError): + left_pad_stack([]) diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index 7dbcbffcd94..6afba92452e 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -7,7 +7,7 @@ # Original file was released under Apache-2.0, with the full license text # available at https://github.com/huggingface/transformers/blob/main/LICENSE. -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from dataclasses import dataclass import numpy as np @@ -46,6 +46,7 @@ ) from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available from vllm_omni.diffusion.layers.rope import RotaryEmbedding +from vllm_omni.diffusion.utils.kv_utils import left_pad_stack from vllm_omni.model_executor.layers.timestep_embedding import timestep_embedding @@ -322,6 +323,57 @@ def seq_lens(self): else: return 0 + @staticmethod + def merge(caches: Sequence["NaiveCache"]) -> "NaiveCache": + """Merge per-branch NaiveCaches into one for batched attention; + this lets us do the forward passes for CFG in one batched pass, + although it's worth noting that this is currently only used + for single request. We need this so that gen mode knows the + respective kv lengths, and can split things back out as needed. + """ + num_layers = caches[0].num_layers + merged = NaiveCache(num_layers) + lens = [c.seq_lens for c in caches] + merged.key_values_lens = lens + + nonempty = [c for c in caches if c.key_cache[0] is not None] + if not nonempty: + return merged + + for layer in range(num_layers): + merged.key_cache[layer] = torch.cat([c.key_cache[layer] for c in nonempty], dim=0) + merged.value_cache[layer] = torch.cat([c.value_cache[layer] for c in nonempty], dim=0) + + return merged + + @staticmethod + def split_with_zeros( + tensor: torch.Tensor, + lengths: Sequence[int], + ) -> list[torch.Tensor | None]: + """Split tensor by lengths, which may include 0 entries, e.g., for splitting cfg + branches out, since text_cfg may have 0 kv length. + + 0 lengths will be replaced with None in the returned list. + """ + # Ensure that the lengths are all nonzero and sum to the first dim of our tensor + if not all(isinstance(ln, int) and ln >= 0 for ln in lengths): + raise ValueError("split lengths must be greater than or equal to zero") + + expected = sum(ln for ln in lengths if ln > 0) + if tensor.shape[0] != expected: + raise ValueError(f"tensor dim 0 ({tensor.shape[0]}) != sum of nonzero lengths ({expected})") + + result: list[torch.Tensor | None] = [] + offset = 0 + for ln in lengths: + if ln > 0: + result.append(tensor[offset : offset + ln]) + offset += ln + else: + result.append(None) + return result + @dataclass class BaseNavitOutputWithPast(ModelOutput): @@ -546,42 +598,29 @@ def _forward_gen( text_v_parts = text_v.split([text_per_branch] * num_branches) vae_v_parts = vae_v.split([vae_per_branch] * num_branches) - q_4d = torch.stack([torch.cat([t, v]) for t, v in zip(text_q_parts, vae_q_parts)]) # [3, 4098, 28, 128] + # Query lengths should not be variable since we + # just split above, so we just concat + stack to 4D + q_4d = torch.stack([torch.cat([t, v]) for t, v in zip(text_q_parts, vae_q_parts)]) - if cache_k is not None: + if cache_k is not None and cache_v is not None: kv_lens = getattr(past_key_values, "key_values_lens", None) if kv_lens is None: per_branch = cache_k.shape[0] // num_branches kv_lens = [per_branch] * num_branches - nonzero = [kv_len for kv_len in kv_lens if kv_len > 0] - ck_parts = list(cache_k.split(nonzero)) if nonzero else [] - cv_parts = list(cache_v.split(nonzero)) if nonzero else [] - ci = 0 - k_branches, v_branches = [], [] - for i in range(num_branches): - kp, vp = [], [] - if kv_lens[i] > 0: - kp.append( - ck_parts[ci] - ) # This is causing issues because it's [10, 0, 10], but why do we have uneven kv cache? - vp.append(cv_parts[ci]) - ci += 1 - kp += [text_k_parts[i], vae_k_parts[i]] - vp += [text_v_parts[i], vae_v_parts[i]] - k_branches.append(torch.cat(kp)) - v_branches.append(torch.cat(vp)) - - max_k = max(b.shape[0] for b in k_branches) - k_4d = cache_k.new_zeros(num_branches, max_k, self.num_kv_heads, self.head_dim) - v_4d = cache_v.new_zeros(num_branches, max_k, self.num_kv_heads, self.head_dim) - mask = torch.zeros(num_branches, max_k, dtype=torch.bool, device=cache_k.device) - for i in range(num_branches): - klen = k_branches[i].shape[0] - pad = max_k - klen - k_4d[i, pad:] = k_branches[i] - v_4d[i, pad:] = v_branches[i] - mask[i, pad:] = True - metadata = DiffusionAttentionMetadata(attn_mask=mask) if torch.any(~mask) else None + + ck_per_branch = NaiveCache.split_with_zeros(cache_k, kv_lens) + cv_per_branch = NaiveCache.split_with_zeros(cache_v, kv_lens) + k_branches = [ + torch.cat([t for t in (ck_per_branch[i], text_k_parts[i], vae_k_parts[i]) if t is not None]) + for i in range(num_branches) + ] + v_branches = [ + torch.cat([t for t in (cv_per_branch[i], text_v_parts[i], vae_v_parts[i]) if t is not None]) + for i in range(num_branches) + ] + k_4d, mask = left_pad_stack(k_branches) + v_4d, _ = left_pad_stack(v_branches) + metadata = DiffusionAttentionMetadata(attn_mask=mask) if mask is not None else None else: k_4d = torch.stack([torch.cat([t, v]) for t, v in zip(text_k_parts, vae_k_parts)]) v_4d = torch.stack([torch.cat([t, v]) for t, v in zip(text_v_parts, vae_v_parts)]) @@ -1914,15 +1953,15 @@ def generate_image( # Each CFG branch runs its own LLM forward; we just need the # per-branch packed_position_ids and past_key_values for # ``Bagel.forward`` to dispatch through. - cfg_branches: dict | None = None + cfg_branch_pids: list[torch.Tensor] | None = None + cfg_branch_caches: list[NaiveCache] | None = None if use_cfg_text: - branches_pid = [packed_position_ids, cfg_text_packed_position_ids] - branches_cache = [past_key_values, cfg_text_past_key_values] + cfg_branch_pids = [packed_position_ids, cfg_text_packed_position_ids] + cfg_branch_caches = [past_key_values, cfg_text_past_key_values] if use_cfg_img: - branches_pid.append(cfg_img_packed_position_ids) - branches_cache.append(cfg_img_past_key_values) - cfg_branches = {"pids": branches_pid, "caches": branches_cache} + cfg_branch_pids.append(cfg_img_packed_position_ids) + cfg_branch_caches.append(cfg_img_past_key_values) if return_trajectory_latents and len(timesteps) > 0: trajectory_latents.append(x_t.clone()) @@ -1955,7 +1994,8 @@ def generate_image( cfg_renorm_type=cfg_renorm_type, cfg_text_scale=cfg_text_scale_, cfg_img_scale=cfg_img_scale_, - cfg_branches=cfg_branches, + cfg_branch_pids=cfg_branch_pids, + cfg_branch_caches=cfg_branch_caches, ) if scheduler is not None: @@ -2299,7 +2339,8 @@ def forward( cfg_renorm_type: str = "global", cfg_text_scale: float = 1.0, cfg_img_scale: float = 1.0, - cfg_branches: dict | None = None, + cfg_branch_pids: list[torch.Tensor] | None = None, + cfg_branch_caches: list[NaiveCache] | None = None, ): # Build query sequence (identical for all CFG branches) packed_text_embedding = self.language_model.forward( @@ -2327,7 +2368,7 @@ def forward( cfg_text_v_t = None cfg_img_v_t = None - if use_cfg and cfg_branches is not None: + if use_cfg and cfg_branch_pids is not None and cfg_branch_caches is not None: # Sequential per-branch CFG forwards (matches upstream lance.py). # The previous batched path concatenated cond + cfg into one LLM # forward, but the block-diagonal attention mask was lost when @@ -2346,12 +2387,10 @@ def _run_branch(branch_pkv, branch_pids): ) return self.llm2vae(out.packed_query_sequence)[packed_vae_token_indexes] - branches_pids = cfg_branches["pids"] - branches_caches = cfg_branches["caches"] - v_t = _run_branch(branches_caches[0], branches_pids[0]) - cfg_text_v_t = _run_branch(branches_caches[1], branches_pids[1]) - if cfg_img_scale > 1.0 and len(branches_caches) > 2: - cfg_img_v_t = _run_branch(branches_caches[2], branches_pids[2]) + v_t = _run_branch(cfg_branch_caches[0], cfg_branch_pids[0]) + cfg_text_v_t = _run_branch(cfg_branch_caches[1], cfg_branch_pids[1]) + if cfg_img_scale > 1.0 and len(cfg_branch_caches) > 2: + cfg_img_v_t = _run_branch(cfg_branch_caches[2], cfg_branch_pids[2]) else: # Single forward (no CFG or outside cfg_interval). output = self.language_model.forward( diff --git a/vllm_omni/diffusion/utils/kv_utils.py b/vllm_omni/diffusion/utils/kv_utils.py new file mode 100644 index 00000000000..78cfd524b1b --- /dev/null +++ b/vllm_omni/diffusion/utils/kv_utils.py @@ -0,0 +1,39 @@ +"""Utilities for batching variable-length tensors in diffusion attention.""" + +import torch + + +def left_pad_stack( + tensors: list[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor | None]: + """Left-pad and stack variable-length tensors. Only dim 0 may vary, + and it's assumed that all tensors are the same dtype & use the same + device. + + Returns (stacked, mask) where mask is a 2D boolean mask, and both + tensors are on the device of the provided tensors. If all tensors are + the same length, None is returned for mask. + """ + trailing_dims = set([ts.shape[1:] for ts in tensors]) + if len(trailing_dims) != 1: + raise ValueError("Tensors must be non-empty and can only vary in dim 0") + trailing = trailing_dims.pop() + + seq_lens = [ts.shape[0] for ts in tensors] + max_len = max(seq_lens) + + # If everything is the same length, we don't need a mask / varlen + if all(sl == max_len for sl in seq_lens): + return torch.stack(tensors), None + + device = tensors[0].device + dtype = tensors[0].dtype + stacked = torch.zeros(len(tensors), max_len, *trailing, dtype=dtype, device=device) + # Create the boolean mask for the input sequences + mask = torch.zeros(len(tensors), max_len, dtype=torch.bool, device=device) + for idx, (ts, sl) in enumerate(zip(tensors, seq_lens)): + pad = max_len - sl + stacked[idx, pad:] = ts + mask[idx, pad:] = True + + return stacked, mask From 8068161b05a990bfc72b770bb4d5de99b287839c Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 2 Jun 2026 23:20:19 +0000 Subject: [PATCH 06/12] add naivecache tests Signed-off-by: Alex Brooks --- .../models/bagel/test_naive_cache.py | 195 ++++++++++++++++++ 1 file changed, 195 insertions(+) create mode 100644 tests/diffusion/models/bagel/test_naive_cache.py diff --git a/tests/diffusion/models/bagel/test_naive_cache.py b/tests/diffusion/models/bagel/test_naive_cache.py new file mode 100644 index 00000000000..4f71934e8ad --- /dev/null +++ b/tests/diffusion/models/bagel/test_naive_cache.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for NaiveCache merge/split logic used in batched CFG.""" + +import pytest +import torch + +from vllm_omni.diffusion.models.bagel.bagel_transformer import NaiveCache + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + +NUM_LAYERS = 2 +NUM_KV_HEADS = 4 +HEAD_DIM = 8 + + +def _make_cache(num_layers, seq_len, num_kv_heads=NUM_KV_HEADS, head_dim=HEAD_DIM, seed=0): + """Create a NaiveCache with deterministic random data. seq_len=0 returns an empty cache.""" + gen = torch.Generator().manual_seed(seed) + cache = NaiveCache(num_layers) + if seq_len == 0: + return cache + for layer in range(num_layers): + cache.key_cache[layer] = torch.randn(seq_len, num_kv_heads, head_dim, generator=gen) + cache.value_cache[layer] = torch.randn(seq_len, num_kv_heads, head_dim, generator=gen) + return cache + + +# ── Basics ── + + +def test_init_creates_none_entries(): + """Ensure the NaiveCache is initialized with None values per layer.""" + cache = NaiveCache(NUM_LAYERS) + assert cache.num_layers == NUM_LAYERS + for layer in range(NUM_LAYERS): + assert cache.key_cache[layer] is None + assert cache.value_cache[layer] is None + + +@pytest.mark.parametrize("seq_len", [0, 10]) +def test_seq_lens_empty(seq_len): + """Ensure that by default, we have 0 seq lens.""" + cache = _make_cache(NUM_LAYERS, seq_len=seq_len) + assert cache.seq_lens == seq_len + assert cache.num_layers == NUM_LAYERS + + +# ── Merge ── + + +def test_merge_two_equal_length(): + """Ensure that we can merge two NaiveCaches that are identically shaped.""" + c0 = _make_cache(NUM_LAYERS, seq_len=5, seed=0) + c1 = _make_cache(NUM_LAYERS, seq_len=5, seed=1) + merged = NaiveCache.merge([c0, c1]) + + assert merged.key_values_lens == [5, 5] + for layer in range(NUM_LAYERS): + assert merged.key_cache[layer].shape[0] == 10 + # the merged cache will just have keys and values per layer concatenated + assert torch.equal(merged.key_cache[layer][:5], c0.key_cache[layer]) + assert torch.equal(merged.key_cache[layer][5:], c1.key_cache[layer]) + assert torch.equal(merged.value_cache[layer][:5], c0.value_cache[layer]) + assert torch.equal(merged.value_cache[layer][5:], c1.value_cache[layer]) + + +def test_merge_three_zero_len_cache(): + """Ensure we handle zero length cache correctly in merge.""" + # NOTE: This is relevant for text_cfg in Bagel, which has a len of 0 by default + gen_cache = _make_cache(NUM_LAYERS, seq_len=10, seed=0) + text_cfg_cache = _make_cache(NUM_LAYERS, seq_len=0) + img_cfg_cache = _make_cache(NUM_LAYERS, seq_len=7, seed=2) + merged = NaiveCache.merge([gen_cache, text_cfg_cache, img_cfg_cache]) + + assert merged.key_values_lens == [10, 0, 7] + for layer in range(NUM_LAYERS): + assert merged.key_cache[layer].shape[0] == 17 + assert torch.equal(merged.key_cache[layer][:10], gen_cache.key_cache[layer]) + assert torch.equal(merged.key_cache[layer][10:], img_cfg_cache.key_cache[layer]) + + +def test_merge_all_empty(): + """Ensure that merging empty caches is well defined.""" + caches = [_make_cache(NUM_LAYERS, seq_len=0) for _ in range(3)] + merged = NaiveCache.merge(caches) + + assert merged.key_values_lens == [0, 0, 0] + for layer in range(NUM_LAYERS): + assert merged.key_cache[layer] is None + assert merged.value_cache[layer] is None + + +def test_merge_single_cache(): + """Ensure merging one cache returns an identical cache.""" + c = _make_cache(NUM_LAYERS, seq_len=8, seed=42) + merged = NaiveCache.merge([c]) + + assert merged.key_values_lens == [8] + for layer in range(NUM_LAYERS): + assert torch.equal(merged.key_cache[layer], c.key_cache[layer]) + + +def test_merge_preserves_dtype(): + """Ensure merging doesn't modify dtypes.""" + cache = NaiveCache(NUM_LAYERS) + for layer in range(NUM_LAYERS): + cache.key_cache[layer] = torch.randn(5, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16) + cache.value_cache[layer] = torch.randn(5, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16) + + merged = NaiveCache.merge([cache]) + assert merged.key_cache[0].dtype == torch.bfloat16 + assert merged.value_cache[0].dtype == torch.bfloat16 + + +# ── split_with_zeros ── + + +def test_split_all_nonzero(): + """Ensure NaiveCache splits in the simple case (all lens nonzero).""" + t = torch.randn(15, NUM_KV_HEADS, HEAD_DIM) + parts = NaiveCache.split_with_zeros(t, [5, 4, 6]) + + assert len(parts) == 3 + assert all(p is not None for p in parts) + assert parts[0].shape[0] == 5 + assert parts[1].shape[0] == 4 + assert parts[2].shape[0] == 6 + assert torch.equal(torch.cat(parts), t) + + +def test_split_with_zero(): + """Ensure NaiveCache split handles zero length correctly (used in Bagel).""" + t = torch.randn(17, NUM_KV_HEADS, HEAD_DIM) + parts = NaiveCache.split_with_zeros(t, [10, 0, 7]) + + assert parts[0].shape[0] == 10 + assert parts[1] is None + assert parts[2].shape[0] == 7 + assert torch.equal(torch.cat([parts[0], parts[2]]), t) + + +def test_split_wrong_sum_raises(): + """Ensure NaiveCache raises if splits don't match the sum of dims on axis 0.""" + t = torch.randn(10, NUM_KV_HEADS, HEAD_DIM) + with pytest.raises(ValueError, match="dim 0"): + NaiveCache.split_with_zeros(t, [5, 3]) + + +def test_split_negative_length_raises(): + """Ensure NaiveCache raises if splits have any negative values.""" + t = torch.randn(10, NUM_KV_HEADS, HEAD_DIM) + with pytest.raises(ValueError, match="greater than or equal to zero"): + NaiveCache.split_with_zeros(t, [5, -1, 6]) + + +def test_split_preserves_dtype(): + """Ensure NaiveCache split preserves dtype.""" + t = torch.randn(10, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16) + parts = NaiveCache.split_with_zeros(t, [4, 6]) + assert parts[0].dtype == torch.bfloat16 + assert parts[1].dtype == torch.bfloat16 + + +def test_round_trip_two_populated(): + """Roundtrip test for merging and resplitting two simple caches.""" + c0 = _make_cache(NUM_LAYERS, seq_len=5, seed=0) + c1 = _make_cache(NUM_LAYERS, seq_len=8, seed=1) + merged = NaiveCache.merge([c0, c1]) + + for layer in range(NUM_LAYERS): + k_parts = NaiveCache.split_with_zeros(merged.key_cache[layer], merged.key_values_lens) + v_parts = NaiveCache.split_with_zeros(merged.value_cache[layer], merged.key_values_lens) + assert torch.equal(k_parts[0], c0.key_cache[layer]) + assert torch.equal(k_parts[1], c1.key_cache[layer]) + assert torch.equal(v_parts[0], c0.value_cache[layer]) + assert torch.equal(v_parts[1], c1.value_cache[layer]) + + +def test_round_trip_three_branches_with_zero_cfg(): + """Roundtrip test with a zero entry (i.e., same as Bagel's gen/text_cfg/img_cfg case).""" + gen_cache = _make_cache(NUM_LAYERS, seq_len=10, seed=0) + text_cfg_cache = _make_cache(NUM_LAYERS, seq_len=0) + img_cfg_cache = _make_cache(NUM_LAYERS, seq_len=7, seed=2) + merged = NaiveCache.merge([gen_cache, text_cfg_cache, img_cfg_cache]) + + for layer in range(NUM_LAYERS): + k_parts = NaiveCache.split_with_zeros(merged.key_cache[layer], merged.key_values_lens) + v_parts = NaiveCache.split_with_zeros(merged.value_cache[layer], merged.key_values_lens) + assert torch.equal(k_parts[0], gen_cache.key_cache[layer]) + assert k_parts[1] is None + assert torch.equal(k_parts[2], img_cfg_cache.key_cache[layer]) + assert torch.equal(v_parts[0], gen_cache.value_cache[layer]) + assert v_parts[1] is None + assert torch.equal(v_parts[2], img_cfg_cache.value_cache[layer]) From cb8959379d3e24a4ffa045c111d51f1eed74ec6b Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Wed, 3 Jun 2026 07:23:02 +0000 Subject: [PATCH 07/12] add from obj tests for naive cache (for kv transfer) Signed-off-by: Alex Brooks --- .../models/bagel/test_naive_cache.py | 46 +++++++++++++++---- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/tests/diffusion/models/bagel/test_naive_cache.py b/tests/diffusion/models/bagel/test_naive_cache.py index 4f71934e8ad..26ad0daeede 100644 --- a/tests/diffusion/models/bagel/test_naive_cache.py +++ b/tests/diffusion/models/bagel/test_naive_cache.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Unit tests for NaiveCache merge/split logic used in batched CFG.""" +from types import SimpleNamespace + import pytest import torch @@ -26,9 +28,6 @@ def _make_cache(num_layers, seq_len, num_kv_heads=NUM_KV_HEADS, head_dim=HEAD_DI return cache -# ── Basics ── - - def test_init_creates_none_entries(): """Ensure the NaiveCache is initialized with None values per layer.""" cache = NaiveCache(NUM_LAYERS) @@ -46,9 +45,7 @@ def test_seq_lens_empty(seq_len): assert cache.num_layers == NUM_LAYERS -# ── Merge ── - - +### Merge tests def test_merge_two_equal_length(): """Ensure that we can merge two NaiveCaches that are identically shaped.""" c0 = _make_cache(NUM_LAYERS, seq_len=5, seed=0) @@ -113,9 +110,7 @@ def test_merge_preserves_dtype(): assert merged.value_cache[0].dtype == torch.bfloat16 -# ── split_with_zeros ── - - +### Split tests def test_split_all_nonzero(): """Ensure NaiveCache splits in the simple case (all lens nonzero).""" t = torch.randn(15, NUM_KV_HEADS, HEAD_DIM) @@ -162,6 +157,39 @@ def test_split_preserves_dtype(): assert parts[1].dtype == torch.bfloat16 +### from_object tests (for kv cache transfer) +def test_from_object_passthrough(): + """Ensure a NaiveCache input is returned as is.""" + cache = _make_cache(NUM_LAYERS, seq_len=5) + assert NaiveCache.from_object(cache) is cache + + +def test_from_object_converts_simple_namespace(): + """Ensure SimpleNamespace with list-based caches converts to NaiveCache.""" + keys = [torch.randn(5, NUM_KV_HEADS, HEAD_DIM) for _ in range(NUM_LAYERS)] + values = [torch.randn(5, NUM_KV_HEADS, HEAD_DIM) for _ in range(NUM_LAYERS)] + ns = SimpleNamespace(key_cache=keys, value_cache=values) + + cache = NaiveCache.from_object(ns) + + assert isinstance(cache, NaiveCache) + assert cache.num_layers == NUM_LAYERS + for i in range(NUM_LAYERS): + assert torch.equal(cache.key_cache[i], keys[i]) + assert torch.equal(cache.value_cache[i], values[i]) + + +def test_from_object_mismatched_lengths_raises(): + """Ensure mismatched key/value cache lengths raise due to strict=True in zip.""" + keys = [torch.randn(5, NUM_KV_HEADS, HEAD_DIM) for _ in range(2)] + values = [torch.randn(5, NUM_KV_HEADS, HEAD_DIM) for _ in range(3)] + ns = SimpleNamespace(key_cache=keys, value_cache=values) + + with pytest.raises(ValueError): + NaiveCache.from_object(ns) + + +### End to end test for split / merge def test_round_trip_two_populated(): """Roundtrip test for merging and resplitting two simple caches.""" c0 = _make_cache(NUM_LAYERS, seq_len=5, seed=0) From c78366a92dc19f7c47308b3ec3ed1bc91d3f2ad5 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Wed, 3 Jun 2026 07:41:13 +0000 Subject: [PATCH 08/12] batch vae passes Signed-off-by: Alex Brooks --- .../models/bagel/bagel_transformer.py | 81 ++++++++++++------- .../diffusion/models/bagel/pipeline_bagel.py | 3 + 2 files changed, 56 insertions(+), 28 deletions(-) diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index 6afba92452e..efabe05924a 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -323,6 +323,25 @@ def seq_lens(self): else: return 0 + @classmethod + def from_object(cls, obj) -> "NaiveCache": + """Convert a duck-typed cache (e.g., SimpleNamespace from KV transfer) + to NaiveCache; in the future, we should find a better way to handle this, + e.g., a model agnostic abstraction for key cache transfer instead of having + this cache live in bagel. + + NOTE: If a NaiveCache is provided, the object is just returned. Otherwise, + we enumerate over the key/value cache values and map layer indices to the + corresponding tensors. + """ + if isinstance(obj, cls): + return obj + cache = cls(len(obj.key_cache)) + for i, (k, v) in enumerate(zip(obj.key_cache, obj.value_cache, strict=True)): + cache.key_cache[i] = k + cache.value_cache[i] = v + return cache + @staticmethod def merge(caches: Sequence["NaiveCache"]) -> "NaiveCache": """Merge per-branch NaiveCaches into one for batched attention; @@ -1790,13 +1809,11 @@ def generate_image( frame_condition_token_indexes = frame_condition_token_indexes.to(x_t.device).long() pinned_x_t = x_t[frame_condition_token_indexes].clone() - # TODO: Re-enable with new reference pixels in Bagel tests - # # Use num_timesteps + 1 sample points so we get `num_timesteps` denoise - # # steps after dropping the terminal t=0 (which has no dt). Upstream - # # Lance / BAGEL both use this convention; without the +1 we silently - # # run one fewer denoise iteration than the user asked for. - # timesteps = torch.linspace(1, 0, num_timesteps + 1, device=x_t.device) - timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device) + # Use num_timesteps + 1 sample points so we get `num_timesteps` denoise + # steps after dropping the terminal t=0 (which has no dt). Upstream + # Lance / BAGEL both use this convention; without the +1 we silently + # run one fewer denoise iteration than the user asked for. + timesteps = torch.linspace(1, 0, num_timesteps + 1, device=x_t.device) timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps) dts = timesteps[:-1] - timesteps[1:] timesteps = timesteps[:-1] @@ -2369,28 +2386,36 @@ def forward( cfg_img_v_t = None if use_cfg and cfg_branch_pids is not None and cfg_branch_caches is not None: - # Sequential per-branch CFG forwards (matches upstream lance.py). - # The previous batched path concatenated cond + cfg into one LLM - # forward, but the block-diagonal attention mask was lost when - # PR #3728 dropped flash_attn_varlen, so branches leaked into - # each other. Running each branch through its own forward is - # numerically identical to upstream. - def _run_branch(branch_pkv, branch_pids): - out = self.language_model.forward( - packed_query_sequence=packed_sequence, - query_lens=packed_seqlens, - packed_query_position_ids=branch_pids, - past_key_values=branch_pkv, - update_past_key_values=False, - is_causal=False, - **extra_inputs, - ) - return self.llm2vae(out.packed_query_sequence)[packed_vae_token_indexes] + num_branches = len(cfg_branch_pids) + seq_len = int(packed_seqlens.sum()) + + batched_sequence = packed_sequence.repeat(num_branches, 1) + batched_vae_indexes = torch.cat([packed_vae_token_indexes + i * seq_len for i in range(num_branches)]) + batched_position_ids = torch.cat(cfg_branch_pids) + batched_seqlens = packed_seqlens.repeat(num_branches) + merged_cache = NaiveCache.merge(cfg_branch_caches) + + if self.use_moe: + batched_text_indices = torch.cat([packed_text_indexes + i * seq_len for i in range(num_branches)]) + extra_inputs["packed_vae_token_indexes"] = batched_vae_indexes + extra_inputs["packed_text_indexes"] = batched_text_indices + + output = self.language_model.forward( + packed_query_sequence=batched_sequence, + query_lens=batched_seqlens, + packed_query_position_ids=batched_position_ids, + past_key_values=merged_cache, + update_past_key_values=False, + is_causal=False, + **extra_inputs, + ) - v_t = _run_branch(cfg_branch_caches[0], cfg_branch_pids[0]) - cfg_text_v_t = _run_branch(cfg_branch_caches[1], cfg_branch_pids[1]) - if cfg_img_scale > 1.0 and len(cfg_branch_caches) > 2: - cfg_img_v_t = _run_branch(cfg_branch_caches[2], cfg_branch_pids[2]) + all_vae_v_t = self.llm2vae(output.packed_query_sequence)[batched_vae_indexes] + vae_per_branch = packed_vae_token_indexes.shape[0] + branch_v_ts = all_vae_v_t.split(vae_per_branch) + v_t = branch_v_ts[0] + cfg_text_v_t = branch_v_ts[1] + cfg_img_v_t = branch_v_ts[2] if len(branch_v_ts) > 2 else None else: # Single forward (no CFG or outside cfg_interval). output = self.language_model.forward( diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index 9f950970ce7..e8ed88f096d 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -395,6 +395,7 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: injected_kv = req.sampling_params.past_key_values if injected_kv is not None: logger.info("Using injected KV Cache (direct)") + injected_kv = NaiveCache.from_object(injected_kv) gen_context["past_key_values"] = injected_kv seq_len = injected_kv.key_cache[0].shape[0] gen_context["kv_lens"] = [seq_len] @@ -431,6 +432,7 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: ) if cfg_text_kv is not None: + cfg_text_kv = NaiveCache.from_object(cfg_text_kv) cfg_text_seq_len = cfg_text_kv.key_cache[0].shape[0] cfg_text_context["past_key_values"] = cfg_text_kv cfg_text_context["kv_lens"] = [cfg_text_seq_len] @@ -458,6 +460,7 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: else: cfg_img_context["ropes"] = [cfg_img_seq_len] else: + cfg_img_kv = NaiveCache.from_object(cfg_img_kv) cfg_img_seq_len = cfg_img_kv.key_cache[0].shape[0] cfg_img_context["past_key_values"] = cfg_img_kv cfg_img_context["kv_lens"] = [cfg_img_seq_len] From 78c83b09302741c7185d256eabf3bf7320baa2fa Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Wed, 3 Jun 2026 08:00:12 +0000 Subject: [PATCH 09/12] update pixel refs Signed-off-by: Alex Brooks --- .../test_bagel_mooncake_connector.py | 20 +++++----- .../test_bagel_shared_memory_connector.py | 40 +++++++++---------- 2 files changed, 29 insertions(+), 31 deletions(-) diff --git a/tests/distributed/omni_connectors/test_bagel_mooncake_connector.py b/tests/distributed/omni_connectors/test_bagel_mooncake_connector.py index 106507114b1..f3aa36b3ae6 100644 --- a/tests/distributed/omni_connectors/test_bagel_mooncake_connector.py +++ b/tests/distributed/omni_connectors/test_bagel_mooncake_connector.py @@ -35,16 +35,16 @@ # "Generated with seed=52, num_inference_steps=14, # prompt='A cute cat'" REFERENCE_PIXELS = [ - {"position": (100, 100), "rgb": (115, 113, 94)}, - {"position": (400, 50), "rgb": (159, 160, 144)}, - {"position": (700, 100), "rgb": (164, 151, 123)}, - {"position": (150, 400), "rgb": (120, 121, 107)}, - {"position": (512, 512), "rgb": (165, 133, 127)}, - {"position": (700, 400), "rgb": (217, 130, 66)}, - {"position": (100, 700), "rgb": (191, 168, 152)}, - {"position": (400, 700), "rgb": (130, 96, 77)}, - {"position": (700, 700), "rgb": (247, 203, 140)}, - {"position": (256, 256), "rgb": (167, 156, 150)}, + {"position": (100, 100), "rgb": (64, 45, 35)}, + {"position": (400, 50), "rgb": (81, 58, 44)}, + {"position": (700, 100), "rgb": (106, 77, 50)}, + {"position": (150, 400), "rgb": (67, 47, 36)}, + {"position": (512, 512), "rgb": (165, 155, 140)}, + {"position": (700, 400), "rgb": (137, 101, 64)}, + {"position": (100, 700), "rgb": (51, 42, 37)}, + {"position": (400, 700), "rgb": (217, 214, 203)}, + {"position": (700, 700), "rgb": (91, 55, 28)}, + {"position": (256, 256), "rgb": (76, 53, 41)}, ] # Maximum allowed difference per color channel diff --git a/tests/distributed/omni_connectors/test_bagel_shared_memory_connector.py b/tests/distributed/omni_connectors/test_bagel_shared_memory_connector.py index eb5a79d6aff..7efb6f094d6 100644 --- a/tests/distributed/omni_connectors/test_bagel_shared_memory_connector.py +++ b/tests/distributed/omni_connectors/test_bagel_shared_memory_connector.py @@ -34,36 +34,34 @@ # prompt='Change the grass color to red', # input image: 2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg REFERENCE_PIXELS = [ - {"position": (100, 100), "rgb": (156, 172, 217)}, + {"position": (100, 100), "rgb": (155, 172, 216)}, {"position": (400, 50), "rgb": (105, 144, 217)}, - {"position": (700, 100), "rgb": (118, 159, 232)}, - {"position": (150, 400), "rgb": (180, 22, 52)}, - {"position": (512, 336), "rgb": (221, 211, 194)}, - {"position": (700, 400), "rgb": (192, 10, 46)}, - {"position": (100, 600), "rgb": (102, 12, 22)}, - {"position": (400, 600), "rgb": (161, 28, 47)}, - {"position": (700, 600), "rgb": (100, 87, 94)}, - {"position": (256, 256), "rgb": (181, 201, 221)}, + {"position": (700, 100), "rgb": (119, 160, 231)}, + {"position": (150, 400), "rgb": (181, 9, 53)}, + {"position": (512, 336), "rgb": (191, 190, 157)}, + {"position": (700, 400), "rgb": (190, 17, 50)}, + {"position": (100, 600), "rgb": (96, 0, 11)}, + {"position": (400, 600), "rgb": (144, 16, 39)}, + {"position": (700, 600), "rgb": (101, 86, 91)}, + {"position": (256, 256), "rgb": (181, 202, 221)}, ] - # text2img reference pixels (aligned with offline `bagel/end2end.py` text2img, 15 steps) # "Generated with seed=52, num_inference_steps=14, # prompt='A cute cat'" TEXT2IMG_REFERENCE_PIXELS = [ - {"position": (100, 100), "rgb": (115, 113, 94)}, - {"position": (400, 50), "rgb": (159, 160, 144)}, - {"position": (700, 100), "rgb": (164, 151, 123)}, - {"position": (150, 400), "rgb": (120, 121, 107)}, - {"position": (512, 512), "rgb": (165, 133, 127)}, - {"position": (700, 400), "rgb": (217, 130, 66)}, - {"position": (100, 700), "rgb": (191, 168, 152)}, - {"position": (400, 700), "rgb": (130, 96, 77)}, - {"position": (700, 700), "rgb": (247, 203, 140)}, - {"position": (256, 256), "rgb": (167, 156, 150)}, + {"position": (100, 100), "rgb": (64, 45, 35)}, + {"position": (400, 50), "rgb": (81, 58, 44)}, + {"position": (700, 100), "rgb": (106, 77, 50)}, + {"position": (150, 400), "rgb": (67, 47, 36)}, + {"position": (512, 512), "rgb": (165, 155, 140)}, + {"position": (700, 400), "rgb": (137, 101, 64)}, + {"position": (100, 700), "rgb": (51, 42, 37)}, + {"position": (400, 700), "rgb": (217, 214, 203)}, + {"position": (700, 700), "rgb": (91, 55, 28)}, + {"position": (256, 256), "rgb": (76, 53, 41)}, ] - PIXEL_TOLERANCE = 10 TEXT2IMG_PIXEL_TOLERANCE = 5 From 5dfce6899d04d84be963574558272081f6860bd7 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Wed, 3 Jun 2026 08:04:11 +0000 Subject: [PATCH 10/12] minor Signed-off-by: Alex Brooks --- vllm_omni/diffusion/models/bagel/bagel_transformer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index efabe05924a..3ceb4056a04 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -1574,7 +1574,6 @@ def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_tok newlens.append(curr_kvlen + num_img_tokens + 2) new_rope.append(curr_position_id + 1) - # TODO - 棄用 (deprecated) kwargs should just be removed here so we do not need to pop them later generation_input = { "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), From 47aeee02ce34911606193c2079793c06a2687ef6 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Wed, 3 Jun 2026 08:06:51 +0000 Subject: [PATCH 11/12] remove outdated comments Signed-off-by: Alex Brooks --- vllm_omni/diffusion/models/bagel/bagel_transformer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index 3ceb4056a04..3d6966fa9d2 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -582,9 +582,7 @@ def _forward_gen( if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None: cache_k = past_key_values.key_cache[self.layer_idx] cache_v = past_key_values.value_cache[self.layer_idx] - ctx_k = torch.cat( - [cache_k, text_k], dim=0 - ) # we are catting a [20,4,128] to [6, 4, 128] -> [26, 4, 128] but not sure why we have 20 + ctx_k = torch.cat([cache_k, text_k], dim=0) ctx_v = torch.cat([cache_v, text_v], dim=0) else: ctx_k = text_k @@ -989,7 +987,7 @@ def forward( if mode == "gen": assert packed_vae_token_indexes is not None assert packed_text_indexes is not None - extra_inputs.update( # we have 3 <4096> tokens + extra_inputs.update( packed_vae_token_indexes=packed_vae_token_indexes, packed_text_indexes=packed_text_indexes, ) From 4666ee3d776ebb2cedb927ab5142b64a89c69819 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Thu, 4 Jun 2026 14:48:37 +0000 Subject: [PATCH 12/12] rebase pixels Signed-off-by: Alex Brooks --- .../test_bagel_mooncake_connector.py | 20 ++++----- .../test_bagel_shared_memory_connector.py | 44 ++++++++++--------- 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/tests/distributed/omni_connectors/test_bagel_mooncake_connector.py b/tests/distributed/omni_connectors/test_bagel_mooncake_connector.py index f3aa36b3ae6..106507114b1 100644 --- a/tests/distributed/omni_connectors/test_bagel_mooncake_connector.py +++ b/tests/distributed/omni_connectors/test_bagel_mooncake_connector.py @@ -35,16 +35,16 @@ # "Generated with seed=52, num_inference_steps=14, # prompt='A cute cat'" REFERENCE_PIXELS = [ - {"position": (100, 100), "rgb": (64, 45, 35)}, - {"position": (400, 50), "rgb": (81, 58, 44)}, - {"position": (700, 100), "rgb": (106, 77, 50)}, - {"position": (150, 400), "rgb": (67, 47, 36)}, - {"position": (512, 512), "rgb": (165, 155, 140)}, - {"position": (700, 400), "rgb": (137, 101, 64)}, - {"position": (100, 700), "rgb": (51, 42, 37)}, - {"position": (400, 700), "rgb": (217, 214, 203)}, - {"position": (700, 700), "rgb": (91, 55, 28)}, - {"position": (256, 256), "rgb": (76, 53, 41)}, + {"position": (100, 100), "rgb": (115, 113, 94)}, + {"position": (400, 50), "rgb": (159, 160, 144)}, + {"position": (700, 100), "rgb": (164, 151, 123)}, + {"position": (150, 400), "rgb": (120, 121, 107)}, + {"position": (512, 512), "rgb": (165, 133, 127)}, + {"position": (700, 400), "rgb": (217, 130, 66)}, + {"position": (100, 700), "rgb": (191, 168, 152)}, + {"position": (400, 700), "rgb": (130, 96, 77)}, + {"position": (700, 700), "rgb": (247, 203, 140)}, + {"position": (256, 256), "rgb": (167, 156, 150)}, ] # Maximum allowed difference per color channel diff --git a/tests/distributed/omni_connectors/test_bagel_shared_memory_connector.py b/tests/distributed/omni_connectors/test_bagel_shared_memory_connector.py index 7efb6f094d6..b6b67b2027b 100644 --- a/tests/distributed/omni_connectors/test_bagel_shared_memory_connector.py +++ b/tests/distributed/omni_connectors/test_bagel_shared_memory_connector.py @@ -7,7 +7,7 @@ - img2img: validates output vs reference pixels within a ±10 tolerance. - text2img: validates output vs reference pixels within a ±5 tolerance (equivalent to `examples/offline_inference/bagel/end2end.py` with - text2img modality and 15 steps). + text2img modality and 14 steps). """ import os @@ -34,34 +34,36 @@ # prompt='Change the grass color to red', # input image: 2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg REFERENCE_PIXELS = [ - {"position": (100, 100), "rgb": (155, 172, 216)}, + {"position": (100, 100), "rgb": (156, 172, 217)}, {"position": (400, 50), "rgb": (105, 144, 217)}, - {"position": (700, 100), "rgb": (119, 160, 231)}, - {"position": (150, 400), "rgb": (181, 9, 53)}, - {"position": (512, 336), "rgb": (191, 190, 157)}, - {"position": (700, 400), "rgb": (190, 17, 50)}, - {"position": (100, 600), "rgb": (96, 0, 11)}, - {"position": (400, 600), "rgb": (144, 16, 39)}, - {"position": (700, 600), "rgb": (101, 86, 91)}, - {"position": (256, 256), "rgb": (181, 202, 221)}, + {"position": (700, 100), "rgb": (118, 159, 232)}, + {"position": (150, 400), "rgb": (180, 22, 52)}, + {"position": (512, 336), "rgb": (221, 211, 194)}, + {"position": (700, 400), "rgb": (192, 10, 46)}, + {"position": (100, 600), "rgb": (102, 12, 22)}, + {"position": (400, 600), "rgb": (161, 28, 47)}, + {"position": (700, 600), "rgb": (100, 87, 94)}, + {"position": (256, 256), "rgb": (181, 201, 221)}, ] -# text2img reference pixels (aligned with offline `bagel/end2end.py` text2img, 15 steps) + +# text2img reference pixels (aligned with offline `bagel/end2end.py` text2img, 14 steps) # "Generated with seed=52, num_inference_steps=14, # prompt='A cute cat'" TEXT2IMG_REFERENCE_PIXELS = [ - {"position": (100, 100), "rgb": (64, 45, 35)}, - {"position": (400, 50), "rgb": (81, 58, 44)}, - {"position": (700, 100), "rgb": (106, 77, 50)}, - {"position": (150, 400), "rgb": (67, 47, 36)}, - {"position": (512, 512), "rgb": (165, 155, 140)}, - {"position": (700, 400), "rgb": (137, 101, 64)}, - {"position": (100, 700), "rgb": (51, 42, 37)}, - {"position": (400, 700), "rgb": (217, 214, 203)}, - {"position": (700, 700), "rgb": (91, 55, 28)}, - {"position": (256, 256), "rgb": (76, 53, 41)}, + {"position": (100, 100), "rgb": (115, 113, 94)}, + {"position": (400, 50), "rgb": (159, 160, 144)}, + {"position": (700, 100), "rgb": (164, 151, 123)}, + {"position": (150, 400), "rgb": (120, 121, 107)}, + {"position": (512, 512), "rgb": (165, 133, 127)}, + {"position": (700, 400), "rgb": (217, 130, 66)}, + {"position": (100, 700), "rgb": (191, 168, 152)}, + {"position": (400, 700), "rgb": (130, 96, 77)}, + {"position": (700, 700), "rgb": (247, 203, 140)}, + {"position": (256, 256), "rgb": (167, 156, 150)}, ] + PIXEL_TOLERANCE = 10 TEXT2IMG_PIXEL_TOLERANCE = 5