diff --git a/tests/diffusion/models/bagel/test_naive_cache.py b/tests/diffusion/models/bagel/test_naive_cache.py new file mode 100644 index 00000000000..26ad0daeede --- /dev/null +++ b/tests/diffusion/models/bagel/test_naive_cache.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for NaiveCache merge/split logic used in batched CFG.""" + +from types import SimpleNamespace + +import pytest +import torch + +from vllm_omni.diffusion.models.bagel.bagel_transformer import NaiveCache + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + +NUM_LAYERS = 2 +NUM_KV_HEADS = 4 +HEAD_DIM = 8 + + +def _make_cache(num_layers, seq_len, num_kv_heads=NUM_KV_HEADS, head_dim=HEAD_DIM, seed=0): + """Create a NaiveCache with deterministic random data. seq_len=0 returns an empty cache.""" + gen = torch.Generator().manual_seed(seed) + cache = NaiveCache(num_layers) + if seq_len == 0: + return cache + for layer in range(num_layers): + cache.key_cache[layer] = torch.randn(seq_len, num_kv_heads, head_dim, generator=gen) + cache.value_cache[layer] = torch.randn(seq_len, num_kv_heads, head_dim, generator=gen) + return cache + + +def test_init_creates_none_entries(): + """Ensure the NaiveCache is initialized with None values per layer.""" + cache = NaiveCache(NUM_LAYERS) + assert cache.num_layers == NUM_LAYERS + for layer in range(NUM_LAYERS): + assert cache.key_cache[layer] is None + assert cache.value_cache[layer] is None + + +@pytest.mark.parametrize("seq_len", [0, 10]) +def test_seq_lens_empty(seq_len): + """Ensure seq_lens is 0 for an empty cache and the dim-0 length for a populated one.""" + cache = _make_cache(NUM_LAYERS, seq_len=seq_len) + assert cache.seq_lens == seq_len + assert cache.num_layers == NUM_LAYERS + + +### Merge tests +def test_merge_two_equal_length(): + """Ensure that we can merge two NaiveCaches that are identically shaped.""" + c0 = _make_cache(NUM_LAYERS, seq_len=5, seed=0) + c1 = _make_cache(NUM_LAYERS, 
seq_len=5, seed=1) + merged = NaiveCache.merge([c0, c1]) + + assert merged.key_values_lens == [5, 5] + for layer in range(NUM_LAYERS): + assert merged.key_cache[layer].shape[0] == 10 + # the merged cache will just have keys and values per layer concatenated + assert torch.equal(merged.key_cache[layer][:5], c0.key_cache[layer]) + assert torch.equal(merged.key_cache[layer][5:], c1.key_cache[layer]) + assert torch.equal(merged.value_cache[layer][:5], c0.value_cache[layer]) + assert torch.equal(merged.value_cache[layer][5:], c1.value_cache[layer]) + + +def test_merge_three_zero_len_cache(): + """Ensure we handle zero length cache correctly in merge.""" + # NOTE: This is relevant for text_cfg in Bagel, which has a len of 0 by default + gen_cache = _make_cache(NUM_LAYERS, seq_len=10, seed=0) + text_cfg_cache = _make_cache(NUM_LAYERS, seq_len=0) + img_cfg_cache = _make_cache(NUM_LAYERS, seq_len=7, seed=2) + merged = NaiveCache.merge([gen_cache, text_cfg_cache, img_cfg_cache]) + + assert merged.key_values_lens == [10, 0, 7] + for layer in range(NUM_LAYERS): + assert merged.key_cache[layer].shape[0] == 17 + assert torch.equal(merged.key_cache[layer][:10], gen_cache.key_cache[layer]) + assert torch.equal(merged.key_cache[layer][10:], img_cfg_cache.key_cache[layer]) + + +def test_merge_all_empty(): + """Ensure that merging empty caches is well defined.""" + caches = [_make_cache(NUM_LAYERS, seq_len=0) for _ in range(3)] + merged = NaiveCache.merge(caches) + + assert merged.key_values_lens == [0, 0, 0] + for layer in range(NUM_LAYERS): + assert merged.key_cache[layer] is None + assert merged.value_cache[layer] is None + + +def test_merge_single_cache(): + """Ensure merging one cache returns an identical cache.""" + c = _make_cache(NUM_LAYERS, seq_len=8, seed=42) + merged = NaiveCache.merge([c]) + + assert merged.key_values_lens == [8] + for layer in range(NUM_LAYERS): + assert torch.equal(merged.key_cache[layer], c.key_cache[layer]) + + +def test_merge_preserves_dtype(): + 
"""Ensure merging doesn't modify dtypes.""" + cache = NaiveCache(NUM_LAYERS) + for layer in range(NUM_LAYERS): + cache.key_cache[layer] = torch.randn(5, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16) + cache.value_cache[layer] = torch.randn(5, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16) + + merged = NaiveCache.merge([cache]) + assert merged.key_cache[0].dtype == torch.bfloat16 + assert merged.value_cache[0].dtype == torch.bfloat16 + + +### Split tests +def test_split_all_nonzero(): + """Ensure NaiveCache splits in the simple case (all lens nonzero).""" + t = torch.randn(15, NUM_KV_HEADS, HEAD_DIM) + parts = NaiveCache.split_with_zeros(t, [5, 4, 6]) + + assert len(parts) == 3 + assert all(p is not None for p in parts) + assert parts[0].shape[0] == 5 + assert parts[1].shape[0] == 4 + assert parts[2].shape[0] == 6 + assert torch.equal(torch.cat(parts), t) + + +def test_split_with_zero(): + """Ensure NaiveCache split handles zero length correctly (used in Bagel).""" + t = torch.randn(17, NUM_KV_HEADS, HEAD_DIM) + parts = NaiveCache.split_with_zeros(t, [10, 0, 7]) + + assert parts[0].shape[0] == 10 + assert parts[1] is None + assert parts[2].shape[0] == 7 + assert torch.equal(torch.cat([parts[0], parts[2]]), t) + + +def test_split_wrong_sum_raises(): + """Ensure NaiveCache raises if splits don't match the sum of dims on axis 0.""" + t = torch.randn(10, NUM_KV_HEADS, HEAD_DIM) + with pytest.raises(ValueError, match="dim 0"): + NaiveCache.split_with_zeros(t, [5, 3]) + + +def test_split_negative_length_raises(): + """Ensure NaiveCache raises if splits have any negative values.""" + t = torch.randn(10, NUM_KV_HEADS, HEAD_DIM) + with pytest.raises(ValueError, match="greater than or equal to zero"): + NaiveCache.split_with_zeros(t, [5, -1, 6]) + + +def test_split_preserves_dtype(): + """Ensure NaiveCache split preserves dtype.""" + t = torch.randn(10, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16) + parts = NaiveCache.split_with_zeros(t, [4, 6]) + assert parts[0].dtype == 
torch.bfloat16 + assert parts[1].dtype == torch.bfloat16 + + +### from_object tests (for kv cache transfer) +def test_from_object_passthrough(): + """Ensure a NaiveCache input is returned as is.""" + cache = _make_cache(NUM_LAYERS, seq_len=5) + assert NaiveCache.from_object(cache) is cache + + +def test_from_object_converts_simple_namespace(): + """Ensure SimpleNamespace with list-based caches converts to NaiveCache.""" + keys = [torch.randn(5, NUM_KV_HEADS, HEAD_DIM) for _ in range(NUM_LAYERS)] + values = [torch.randn(5, NUM_KV_HEADS, HEAD_DIM) for _ in range(NUM_LAYERS)] + ns = SimpleNamespace(key_cache=keys, value_cache=values) + + cache = NaiveCache.from_object(ns) + + assert isinstance(cache, NaiveCache) + assert cache.num_layers == NUM_LAYERS + for i in range(NUM_LAYERS): + assert torch.equal(cache.key_cache[i], keys[i]) + assert torch.equal(cache.value_cache[i], values[i]) + + +def test_from_object_mismatched_lengths_raises(): + """Ensure mismatched key/value cache lengths raise due to strict=True in zip.""" + keys = [torch.randn(5, NUM_KV_HEADS, HEAD_DIM) for _ in range(2)] + values = [torch.randn(5, NUM_KV_HEADS, HEAD_DIM) for _ in range(3)] + ns = SimpleNamespace(key_cache=keys, value_cache=values) + + with pytest.raises(ValueError): + NaiveCache.from_object(ns) + + +### End to end test for split / merge +def test_round_trip_two_populated(): + """Roundtrip test for merging and resplitting two simple caches.""" + c0 = _make_cache(NUM_LAYERS, seq_len=5, seed=0) + c1 = _make_cache(NUM_LAYERS, seq_len=8, seed=1) + merged = NaiveCache.merge([c0, c1]) + + for layer in range(NUM_LAYERS): + k_parts = NaiveCache.split_with_zeros(merged.key_cache[layer], merged.key_values_lens) + v_parts = NaiveCache.split_with_zeros(merged.value_cache[layer], merged.key_values_lens) + assert torch.equal(k_parts[0], c0.key_cache[layer]) + assert torch.equal(k_parts[1], c1.key_cache[layer]) + assert torch.equal(v_parts[0], c0.value_cache[layer]) + assert torch.equal(v_parts[1], 
c1.value_cache[layer]) + + +def test_round_trip_three_branches_with_zero_cfg(): + """Roundtrip test with a zero entry (i.e., same as Bagel's gen/text_cfg/img_cfg case).""" + gen_cache = _make_cache(NUM_LAYERS, seq_len=10, seed=0) + text_cfg_cache = _make_cache(NUM_LAYERS, seq_len=0) + img_cfg_cache = _make_cache(NUM_LAYERS, seq_len=7, seed=2) + merged = NaiveCache.merge([gen_cache, text_cfg_cache, img_cfg_cache]) + + for layer in range(NUM_LAYERS): + k_parts = NaiveCache.split_with_zeros(merged.key_cache[layer], merged.key_values_lens) + v_parts = NaiveCache.split_with_zeros(merged.value_cache[layer], merged.key_values_lens) + assert torch.equal(k_parts[0], gen_cache.key_cache[layer]) + assert k_parts[1] is None + assert torch.equal(k_parts[2], img_cfg_cache.key_cache[layer]) + assert torch.equal(v_parts[0], gen_cache.value_cache[layer]) + assert v_parts[1] is None + assert torch.equal(v_parts[2], img_cfg_cache.value_cache[layer]) diff --git a/tests/diffusion/utils/__init__.py b/tests/diffusion/utils/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/diffusion/utils/test_kv_utils.py b/tests/diffusion/utils/test_kv_utils.py new file mode 100644 index 00000000000..6b523b92033 --- /dev/null +++ b/tests/diffusion/utils/test_kv_utils.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for kv utils.""" + +import pytest +import torch + +from vllm_omni.diffusion.utils.kv_utils import left_pad_stack + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def test_uniform_lengths(): + """Ensure uniform 3D tensors correctly stack to 4D and have no mask.""" + tensors = [torch.randn(10, 4, 128) for _ in range(3)] + stacked, mask = left_pad_stack(tensors) + assert stacked.shape == (3, 10, 4, 128) + assert mask is None + for i in range(3): + assert torch.equal(stacked[i], tensors[i]) + + +def test_variable_lengths(): + """Ensure variable 3D tensors 
correctly stack to 4D with a mask.""" + t1 = torch.ones(5, 2, 4) + t2 = torch.ones(8, 2, 4) + t3 = torch.ones(3, 2, 4) + stacked, mask = left_pad_stack([t1, t2, t3]) + + assert stacked.shape == (3, 8, 2, 4) + assert mask is not None + assert mask.shape == (3, 8) + # Ensure summing over dim 1 gives our seq lens back + assert mask.sum(dim=1).tolist() == [5, 8, 3] + + +def test_single_tensor_is_4d(): + """Ensure a single 3D tensor is expanded to 4D.""" + t = torch.randn(7, 4, 128) + stacked, mask = left_pad_stack([t]) + assert stacked.shape == (1, 7, 4, 128) + assert mask is None + assert torch.equal(stacked[0], t) + + +def test_preserves_device_and_dtype(): + """Ensure device/dtype is preserved.""" + t1 = torch.randn(3, 2, dtype=torch.bfloat16) + t2 = torch.randn(5, 2, dtype=torch.bfloat16) + stacked, mask = left_pad_stack([t1, t2]) + assert stacked.dtype == torch.bfloat16 + assert stacked.device == t1.device + assert mask.device == t1.device + + +def test_mismatched_trailing_shapes_raises(): + """Ensure that mismatched dims outside of 0 explodes.""" + t1 = torch.randn(5, 4, 128) + t2 = torch.randn(5, 8, 128) + with pytest.raises(ValueError): + left_pad_stack([t1, t2]) + + +def test_mismatched_ndim_raises(): + """Ensure that mismatched ndims explodes.""" + t1 = torch.randn(5, 4) + t2 = torch.randn(5, 4, 128) + with pytest.raises(ValueError): + left_pad_stack([t1, t2]) + + +def test_empty_list_raises(): + """Ensure that tensors must be nonempty.""" + with pytest.raises(ValueError): + left_pad_stack([]) diff --git a/tests/distributed/omni_connectors/test_bagel_shared_memory_connector.py b/tests/distributed/omni_connectors/test_bagel_shared_memory_connector.py index eb5a79d6aff..b6b67b2027b 100644 --- a/tests/distributed/omni_connectors/test_bagel_shared_memory_connector.py +++ b/tests/distributed/omni_connectors/test_bagel_shared_memory_connector.py @@ -7,7 +7,7 @@ - img2img: validates output vs reference pixels within a ±10 tolerance. 
- text2img: validates output vs reference pixels within a ±5 tolerance (equivalent to `examples/offline_inference/bagel/end2end.py` with - text2img modality and 15 steps). + text2img modality and 14 steps). """ import os @@ -47,7 +47,7 @@ ] -# text2img reference pixels (aligned with offline `bagel/end2end.py` text2img, 15 steps) +# text2img reference pixels (aligned with offline `bagel/end2end.py` text2img, 14 steps) # "Generated with seed=52, num_inference_steps=14, # prompt='A cute cat'" TEXT2IMG_REFERENCE_PIXELS = [ diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index 4231d4cc638..3d6966fa9d2 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -7,7 +7,7 @@ # Original file was released under Apache-2.0, with the full license text # available at https://github.com/huggingface/transformers/blob/main/LICENSE. -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from dataclasses import dataclass import numpy as np @@ -46,6 +46,7 @@ ) from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available from vllm_omni.diffusion.layers.rope import RotaryEmbedding +from vllm_omni.diffusion.utils.kv_utils import left_pad_stack from vllm_omni.model_executor.layers.timestep_embedding import timestep_embedding @@ -306,6 +307,10 @@ class NaiveCache: def __init__(self, num_layers): self.key_cache = {k: None for k in range(num_layers)} self.value_cache = {k: None for k in range(num_layers)} + # Track kv_lens; we need this because we pack the forward passes + # for CFG into a single forward call and the kv length may be different, + # e.g., due to 0 kvs for text_cfg path and nonzero for others + self.key_values_lens: list[int] | None = None @property def num_layers(self): @@ -318,6 +323,76 @@ def seq_lens(self): else: return 0 + @classmethod + def from_object(cls, 
obj) -> "NaiveCache": + """Convert a duck-typed cache (e.g., SimpleNamespace from KV transfer) + to NaiveCache; in the future, we should find a better way to handle this, + e.g., a model agnostic abstraction for key cache transfer instead of having + this cache live in bagel. + + NOTE: If a NaiveCache is provided, the object is just returned. Otherwise, + we enumerate over the key/value cache values and map layer indices to the + corresponding tensors. + """ + if isinstance(obj, cls): + return obj + cache = cls(len(obj.key_cache)) + for i, (k, v) in enumerate(zip(obj.key_cache, obj.value_cache, strict=True)): + cache.key_cache[i] = k + cache.value_cache[i] = v + return cache + + @staticmethod + def merge(caches: Sequence["NaiveCache"]) -> "NaiveCache": + """Merge per-branch NaiveCaches into one for batched attention; + this lets us do the forward passes for CFG in one batched pass, + although it's worth noting that this is currently only used + for single request. We need this so that gen mode knows the + respective kv lengths, and can split things back out as needed. + """ + num_layers = caches[0].num_layers + merged = NaiveCache(num_layers) + lens = [c.seq_lens for c in caches] + merged.key_values_lens = lens + + nonempty = [c for c in caches if c.key_cache[0] is not None] + if not nonempty: + return merged + + for layer in range(num_layers): + merged.key_cache[layer] = torch.cat([c.key_cache[layer] for c in nonempty], dim=0) + merged.value_cache[layer] = torch.cat([c.value_cache[layer] for c in nonempty], dim=0) + + return merged + + @staticmethod + def split_with_zeros( + tensor: torch.Tensor, + lengths: Sequence[int], + ) -> list[torch.Tensor | None]: + """Split tensor by lengths, which may include 0 entries, e.g., for splitting cfg + branches out, since text_cfg may have 0 kv length. + + 0 lengths will be replaced with None in the returned list. 
+ """ + # Ensure that the lengths are all non-negative ints and sum to the first dim of our tensor + if not all(isinstance(ln, int) and ln >= 0 for ln in lengths): + raise ValueError("split lengths must be greater than or equal to zero") + + expected = sum(ln for ln in lengths if ln > 0) + if tensor.shape[0] != expected: + raise ValueError(f"tensor dim 0 ({tensor.shape[0]}) != sum of nonzero lengths ({expected})") + + result: list[torch.Tensor | None] = [] + offset = 0 + for ln in lengths: + if ln > 0: + result.append(tensor[offset : offset + ln]) + offset += ln + else: + result.append(None) + return result + @dataclass class BaseNavitOutputWithPast(ModelOutput): @@ -431,6 +506,7 @@ def _is_sp_active(self) -> bool: def _forward_gen( self, packed_query_sequence: torch.Tensor, + query_lens: torch.Tensor, packed_query_position_embeddings: torch.Tensor, past_key_values: NaiveCache | None, packed_vae_token_indexes: torch.Tensor, @@ -455,6 +531,7 @@ def _forward_gen( Currently we shouldn't need it in the model, and it would be ideal to handle packing/batching etc in a more model agnostic way. 
""" + cache_k = cache_v = None packed_query_sequence = packed_query_sequence.to(torch.bfloat16) packed_text_query_sequence = packed_query_sequence[packed_text_indexes] packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes] @@ -527,15 +604,50 @@ def _forward_gen( ), ) else: - q = torch.cat([text_q, vae_q], dim=0).unsqueeze(0) - k = torch.cat([ctx_k, vae_k], dim=0).unsqueeze(0) - v = torch.cat([ctx_v, vae_v], dim=0).unsqueeze(0) - attn_out = self.attn_noncausal(q, k, v) + num_branches = len(query_lens) + text_per_branch = text_q.shape[0] // num_branches + vae_per_branch = vae_q.shape[0] // num_branches + + text_q_parts = text_q.split([text_per_branch] * num_branches) + vae_q_parts = vae_q.split([vae_per_branch] * num_branches) + text_k_parts = text_k.split([text_per_branch] * num_branches) + vae_k_parts = vae_k.split([vae_per_branch] * num_branches) + text_v_parts = text_v.split([text_per_branch] * num_branches) + vae_v_parts = vae_v.split([vae_per_branch] * num_branches) + + # Query lengths should not be variable since we + # just split above, so we just concat + stack to 4D + q_4d = torch.stack([torch.cat([t, v]) for t, v in zip(text_q_parts, vae_q_parts)]) + + if cache_k is not None and cache_v is not None: + kv_lens = getattr(past_key_values, "key_values_lens", None) + if kv_lens is None: + per_branch = cache_k.shape[0] // num_branches + kv_lens = [per_branch] * num_branches + + ck_per_branch = NaiveCache.split_with_zeros(cache_k, kv_lens) + cv_per_branch = NaiveCache.split_with_zeros(cache_v, kv_lens) + k_branches = [ + torch.cat([t for t in (ck_per_branch[i], text_k_parts[i], vae_k_parts[i]) if t is not None]) + for i in range(num_branches) + ] + v_branches = [ + torch.cat([t for t in (cv_per_branch[i], text_v_parts[i], vae_v_parts[i]) if t is not None]) + for i in range(num_branches) + ] + k_4d, mask = left_pad_stack(k_branches) + v_4d, _ = left_pad_stack(v_branches) + metadata = DiffusionAttentionMetadata(attn_mask=mask) if mask is not 
None else None + else: + k_4d = torch.stack([torch.cat([t, v]) for t, v in zip(text_k_parts, vae_k_parts)]) + v_4d = torch.stack([torch.cat([t, v]) for t, v in zip(text_v_parts, vae_v_parts)]) + metadata = None - text_len = text_q.shape[0] - attn_out = attn_out.squeeze(0) - text_attn = attn_out[:text_len].reshape(text_len, self.q_size) - vae_attn = attn_out[text_len:].reshape(-1, self.q_size) + attn_out = self.attn_noncausal(q_4d, k_4d, v_4d, metadata) + q_per_branch = int(query_lens[0]) + attn_out = attn_out.reshape(num_branches, q_per_branch, self.q_size) + text_attn = attn_out[:, :text_per_branch].reshape(-1, self.q_size) + vae_attn = attn_out[:, text_per_branch:].reshape(-1, self.q_size) # Apply output projections text_out, _ = self.o_proj(text_attn) @@ -675,6 +787,7 @@ def forward( raise ValueError("Generation model for Bagel requires non-causal attention") return self._forward_gen( packed_query_sequence=packed_query_sequence, + query_lens=query_lens, packed_query_position_embeddings=packed_query_position_embeddings, past_key_values=past_key_values, packed_vae_token_indexes=packed_vae_token_indexes, @@ -1854,15 +1967,15 @@ def generate_image( # Each CFG branch runs its own LLM forward; we just need the # per-branch packed_position_ids and past_key_values for # ``Bagel.forward`` to dispatch through. 
- cfg_branches: dict | None = None + cfg_branch_pids: list[torch.Tensor] | None = None + cfg_branch_caches: list[NaiveCache] | None = None if use_cfg_text: - branches_pid = [packed_position_ids, cfg_text_packed_position_ids] - branches_cache = [past_key_values, cfg_text_past_key_values] + cfg_branch_pids = [packed_position_ids, cfg_text_packed_position_ids] + cfg_branch_caches = [past_key_values, cfg_text_past_key_values] if use_cfg_img: - branches_pid.append(cfg_img_packed_position_ids) - branches_cache.append(cfg_img_past_key_values) - cfg_branches = {"pids": branches_pid, "caches": branches_cache} + cfg_branch_pids.append(cfg_img_packed_position_ids) + cfg_branch_caches.append(cfg_img_past_key_values) if return_trajectory_latents and len(timesteps) > 0: trajectory_latents.append(x_t.clone()) @@ -1895,7 +2008,8 @@ def generate_image( cfg_renorm_type=cfg_renorm_type, cfg_text_scale=cfg_text_scale_, cfg_img_scale=cfg_img_scale_, - cfg_branches=cfg_branches, + cfg_branch_pids=cfg_branch_pids, + cfg_branch_caches=cfg_branch_caches, ) if scheduler is not None: @@ -2239,7 +2353,8 @@ def forward( cfg_renorm_type: str = "global", cfg_text_scale: float = 1.0, cfg_img_scale: float = 1.0, - cfg_branches: dict | None = None, + cfg_branch_pids: list[torch.Tensor] | None = None, + cfg_branch_caches: list[NaiveCache] | None = None, ): # Build query sequence (identical for all CFG branches) packed_text_embedding = self.language_model.forward( @@ -2267,31 +2382,37 @@ def forward( cfg_text_v_t = None cfg_img_v_t = None - if use_cfg and cfg_branches is not None: - # Sequential per-branch CFG forwards (matches upstream lance.py). - # The previous batched path concatenated cond + cfg into one LLM - # forward, but the block-diagonal attention mask was lost when - # PR #3728 dropped flash_attn_varlen, so branches leaked into - # each other. Running each branch through its own forward is - # numerically identical to upstream. 
- def _run_branch(branch_pkv, branch_pids): - out = self.language_model.forward( - packed_query_sequence=packed_sequence, - query_lens=packed_seqlens, - packed_query_position_ids=branch_pids, - past_key_values=branch_pkv, - update_past_key_values=False, - is_causal=False, - **extra_inputs, - ) - return self.llm2vae(out.packed_query_sequence)[packed_vae_token_indexes] - - branches_pids = cfg_branches["pids"] - branches_caches = cfg_branches["caches"] - v_t = _run_branch(branches_caches[0], branches_pids[0]) - cfg_text_v_t = _run_branch(branches_caches[1], branches_pids[1]) - if cfg_img_scale > 1.0 and len(branches_caches) > 2: - cfg_img_v_t = _run_branch(branches_caches[2], branches_pids[2]) + if use_cfg and cfg_branch_pids is not None and cfg_branch_caches is not None: + num_branches = len(cfg_branch_pids) + seq_len = int(packed_seqlens.sum()) + + batched_sequence = packed_sequence.repeat(num_branches, 1) + batched_vae_indexes = torch.cat([packed_vae_token_indexes + i * seq_len for i in range(num_branches)]) + batched_position_ids = torch.cat(cfg_branch_pids) + batched_seqlens = packed_seqlens.repeat(num_branches) + merged_cache = NaiveCache.merge(cfg_branch_caches) + + if self.use_moe: + batched_text_indices = torch.cat([packed_text_indexes + i * seq_len for i in range(num_branches)]) + extra_inputs["packed_vae_token_indexes"] = batched_vae_indexes + extra_inputs["packed_text_indexes"] = batched_text_indices + + output = self.language_model.forward( + packed_query_sequence=batched_sequence, + query_lens=batched_seqlens, + packed_query_position_ids=batched_position_ids, + past_key_values=merged_cache, + update_past_key_values=False, + is_causal=False, + **extra_inputs, + ) + + all_vae_v_t = self.llm2vae(output.packed_query_sequence)[batched_vae_indexes] + vae_per_branch = packed_vae_token_indexes.shape[0] + branch_v_ts = all_vae_v_t.split(vae_per_branch) + v_t = branch_v_ts[0] + cfg_text_v_t = branch_v_ts[1] + cfg_img_v_t = branch_v_ts[2] if len(branch_v_ts) > 2 
else None else: # Single forward (no CFG or outside cfg_interval). output = self.language_model.forward( diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index 56a5b0dcc40..e8ed88f096d 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -395,6 +395,7 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: injected_kv = req.sampling_params.past_key_values if injected_kv is not None: logger.info("Using injected KV Cache (direct)") + injected_kv = NaiveCache.from_object(injected_kv) gen_context["past_key_values"] = injected_kv seq_len = injected_kv.key_cache[0].shape[0] gen_context["kv_lens"] = [seq_len] @@ -431,6 +432,7 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: ) if cfg_text_kv is not None: + cfg_text_kv = NaiveCache.from_object(cfg_text_kv) cfg_text_seq_len = cfg_text_kv.key_cache[0].shape[0] cfg_text_context["past_key_values"] = cfg_text_kv cfg_text_context["kv_lens"] = [cfg_text_seq_len] @@ -458,6 +460,7 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: else: cfg_img_context["ropes"] = [cfg_img_seq_len] else: + cfg_img_kv = NaiveCache.from_object(cfg_img_kv) cfg_img_seq_len = cfg_img_kv.key_cache[0].shape[0] cfg_img_context["past_key_values"] = cfg_img_kv cfg_img_context["kv_lens"] = [cfg_img_seq_len] @@ -553,6 +556,8 @@ def vae_transforms(img): for k, v in gen_input_img.items(): if torch.is_tensor(v): gen_input_img[k] = v.to(self.device) + for k in ("packed_indexes", "packed_key_value_indexes", "key_values_lens"): + gen_input_img.pop(k, None) with torch.autocast( device_type=self.device.type, enabled=self.device.type != "cpu", @@ -604,7 +609,7 @@ def vae_transforms(img): # cfg_text_context: update with negative prompt (no text condition). # When empty, keep cfg_text_context as-is (kv_lens=0) to match - # original BAGEL; _merge_naive_caches handles None KV entries. 
+ # original BAGEL. neg_prompt = extra_args.get("negative_prompt", "") if neg_prompt: neg_input, neg_newlens, neg_rope = self.bagel.prepare_prompts( diff --git a/vllm_omni/diffusion/utils/kv_utils.py b/vllm_omni/diffusion/utils/kv_utils.py new file mode 100644 index 00000000000..78cfd524b1b --- /dev/null +++ b/vllm_omni/diffusion/utils/kv_utils.py @@ -0,0 +1,39 @@ +"""Utilities for batching variable-length tensors in diffusion attention.""" + +import torch + + +def left_pad_stack( + tensors: list[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor | None]: + """Left-pad and stack variable-length tensors. Only dim 0 may vary, + and it's assumed that all tensors are the same dtype & use the same + device. + + Returns (stacked, mask) where mask is a 2D boolean mask that is True at + real (non-padded) positions; both are on the device of the provided tensors. + If all tensors are the same length, None is returned for mask. + """ + trailing_dims = set([ts.shape[1:] for ts in tensors]) + if len(trailing_dims) != 1: + raise ValueError("Tensors must be non-empty and can only vary in dim 0") + trailing = trailing_dims.pop() + + seq_lens = [ts.shape[0] for ts in tensors] + max_len = max(seq_lens) + + # If everything is the same length, we don't need a mask / varlen + if all(sl == max_len for sl in seq_lens): + return torch.stack(tensors), None + + device = tensors[0].device + dtype = tensors[0].dtype + stacked = torch.zeros(len(tensors), max_len, *trailing, dtype=dtype, device=device) + # Boolean mask: True marks real entries, False marks the left padding + mask = torch.zeros(len(tensors), max_len, dtype=torch.bool, device=device) + for idx, (ts, sl) in enumerate(zip(tensors, seq_lens)): + pad = max_len - sl + stacked[idx, pad:] = ts + mask[idx, pad:] = True + + return stacked, mask