From 7ec62e8e2d3c7a22d71980f4d979314b1e3a8fb6 Mon Sep 17 00:00:00 2001 From: Ding Zuhao Date: Mon, 2 Mar 2026 14:46:47 +0800 Subject: [PATCH 1/4] feat(bagel): add CFG parallel mode for distributed denoising Add parallel CFG denoising path where 3 branches (gen, text_cfg, img_cfg) are distributed across GPUs via cfg_parallel infrastructure. - Extract _combine_cfg() for reusable CFG combination logic with renorm - Add _generate_image_parallel() for multi-GPU denoising loop - Support cfg_parallel_size=1 (batched), 2 (text CFG only), 3 (all branches) - Add validation guards for cfg_parallel_size vs cfg_img_scale consistency - Relax cfg_parallel_size validation in DiffusionParallelConfig to allow [1,2,3] Signed-off-by: Ding Zuhao --- vllm_omni/diffusion/data.py | 4 +- .../models/bagel/bagel_transformer.py | 342 ++++++++++++++++-- 2 files changed, 314 insertions(+), 32 deletions(-) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index bd5af226373..da68adbdb17 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -77,7 +77,9 @@ def _validate_parallel_config(self) -> Self: assert self.ulysses_degree > 0, "Ulysses degree must be > 0" assert self.ring_degree > 0, "Ring degree must be > 0" assert self.cfg_parallel_size > 0, "CFG parallel size must be > 0" - assert self.cfg_parallel_size in [1, 2], f"CFG parallel size must be 1 or 2, but got {self.cfg_parallel_size}" + assert self.cfg_parallel_size in [1, 2, 3], ( + f"CFG parallel size must be 1, 2, or 3, but got {self.cfg_parallel_size}" + ) assert self.vae_patch_parallel_size > 0, "VAE patch parallel size must be > 0" assert self.sequence_parallel_size == self.ulysses_degree * self.ring_degree, ( "Sequence parallel size must be equal to the product of ulysses degree and ring degree," diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index 541a037165a..e715bdaec8c 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -32,6 +32,11 @@ from vllm.transformers_utils.configs.bagel import BagelConfig from vllm_omni.diffusion.attention.backends.utils.fa import flash_attn_varlen_func +from vllm_omni.diffusion.distributed.parallel_state import ( + get_cfg_group, + get_classifier_free_guidance_rank, + get_classifier_free_guidance_world_size, +) from vllm_omni.diffusion.layers.rope import RotaryEmbedding @@ -1330,7 +1335,6 @@ def generate_image( cfg_img_past_key_values: NaiveCache | None = None, cfg_img_key_values_lens: torch.IntTensor | None = None, cfg_img_packed_key_value_indexes: torch.LongTensor | None = None, - cfg_type: str = "parallel", ): x_t = packed_init_noises @@ -1339,9 +1343,45 @@ def generate_image( dts = timesteps[:-1] - timesteps[1:] timesteps = timesteps[:-1] - # ── Pre-compute batched CFG state (merged caches + indices) ── use_cfg_text = cfg_text_scale > 1.0 use_cfg_img = cfg_img_scale > 1.0 + + # ── Detect CFG parallel mode ── + cfg_parallel_ready = use_cfg_text and get_classifier_free_guidance_world_size() > 1 + + if cfg_parallel_ready: + return self._generate_image_parallel( + x_t=x_t, + timesteps=timesteps, + dts=dts, + packed_text_ids=packed_text_ids, + packed_text_indexes=packed_text_indexes, + packed_vae_position_ids=packed_vae_position_ids, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_seqlens=packed_seqlens, + packed_position_ids=packed_position_ids, + packed_indexes=packed_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + cfg_renorm_min=cfg_renorm_min, + cfg_renorm_type=cfg_renorm_type, + cfg_interval=cfg_interval, + cfg_text_scale=cfg_text_scale, + cfg_text_packed_query_indexes=cfg_text_packed_query_indexes, + cfg_text_packed_position_ids=cfg_text_packed_position_ids, + cfg_text_past_key_values=cfg_text_past_key_values, + cfg_text_key_values_lens=cfg_text_key_values_lens, + cfg_text_packed_key_value_indexes=cfg_text_packed_key_value_indexes, + cfg_img_scale=cfg_img_scale, + cfg_img_packed_query_indexes=cfg_img_packed_query_indexes, + cfg_img_packed_position_ids=cfg_img_packed_position_ids, + cfg_img_past_key_values=cfg_img_past_key_values, + cfg_img_key_values_lens=cfg_img_key_values_lens, + cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes, + ) + + # ── Batched CFG mode (cfg_parallel_size=1) ── cfg_batched = None if use_cfg_text: @@ -1430,6 +1470,266 @@ def generate_image( unpacked_latent = x_t.split((packed_seqlens - 2).tolist()) return unpacked_latent + def _generate_image_parallel( + self, + x_t: torch.Tensor, + timesteps: torch.Tensor, + dts: torch.Tensor, + packed_text_ids: torch.LongTensor, + packed_text_indexes: torch.LongTensor, + packed_vae_position_ids: torch.LongTensor, + packed_vae_token_indexes: torch.LongTensor, + packed_seqlens: torch.IntTensor, + packed_position_ids: torch.LongTensor, + packed_indexes: torch.LongTensor, + past_key_values: NaiveCache, + key_values_lens: torch.IntTensor, + packed_key_value_indexes: torch.LongTensor, + cfg_renorm_min: float, + cfg_renorm_type: str, + cfg_interval: tuple[float, float], + cfg_text_scale: float, + cfg_text_packed_query_indexes: torch.LongTensor | None, + cfg_text_packed_position_ids: torch.LongTensor | None, + cfg_text_past_key_values: NaiveCache | None, + cfg_text_key_values_lens: torch.IntTensor | None, + cfg_text_packed_key_value_indexes: torch.LongTensor | None, + cfg_img_scale: float, + cfg_img_packed_query_indexes: torch.LongTensor | None, + cfg_img_packed_position_ids: torch.LongTensor | None, + cfg_img_past_key_values: NaiveCache | None, + cfg_img_key_values_lens: torch.IntTensor | None, + cfg_img_packed_key_value_indexes: torch.LongTensor | None, + ): + """CFG parallel denoising loop: each rank computes one CFG branch. + + Rank 0: gen branch (full conditioning) + Rank 1: text_cfg branch (unconditional text) + Rank 2: img_cfg branch (no image condition), only when cfg_img_scale > 1.0 + """ + cfg_group = get_cfg_group() + cfg_rank = get_classifier_free_guidance_rank() + cfg_world_size = get_classifier_free_guidance_world_size() + use_cfg_img = cfg_img_scale > 1.0 + + # Validate cfg_parallel_size vs cfg_img_scale consistency + if cfg_world_size == 3 and not use_cfg_img: + raise ValueError( + f"cfg_parallel_size=3 requires cfg_img_scale > 1.0, " + f"but got cfg_img_scale={cfg_img_scale}. " + f"Use cfg_parallel_size=2 for text-only CFG parallel(text2img), or set cfg_img_scale > 1.0." + ) + if cfg_world_size == 2 and use_cfg_img: + raise ValueError( + f"Image CFG (cfg_img_scale={cfg_img_scale}) requires cfg_parallel_size=3, " + f"but got cfg_parallel_size=2. " + f"Use cfg_parallel_size=3 to enable image CFG in parallel mode." + ) + + # Select this rank's branch inputs + if cfg_rank == 0: + # Gen branch: use main inputs directly + branch_position_ids = packed_position_ids + branch_indexes = packed_indexes + branch_past_key_values = past_key_values + branch_key_values_lens = key_values_lens + branch_key_value_indexes = packed_key_value_indexes + elif cfg_rank == 1: + # Text CFG branch + branch_position_ids = cfg_text_packed_position_ids + branch_indexes = cfg_text_packed_query_indexes + branch_past_key_values = cfg_text_past_key_values + branch_key_values_lens = cfg_text_key_values_lens + branch_key_value_indexes = cfg_text_packed_key_value_indexes + elif cfg_rank == 2: + # Image CFG branch + branch_position_ids = cfg_img_packed_position_ids + branch_indexes = cfg_img_packed_query_indexes + branch_past_key_values = cfg_img_past_key_values + branch_key_values_lens = cfg_img_key_values_lens + branch_key_value_indexes = cfg_img_packed_key_value_indexes + else: + raise RuntimeError(f"Unexpected cfg_rank={cfg_rank} for Bagel 3-branch CFG parallel") + + for i, t in enumerate(timesteps): + timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device) + if t > cfg_interval[0] and t <= cfg_interval[1]: + cfg_text_scale_ = cfg_text_scale + cfg_img_scale_ = cfg_img_scale + else: + cfg_text_scale_ = 1.0 + cfg_img_scale_ = 1.0 + + use_cfg_this_step = cfg_text_scale_ > 1.0 + + if use_cfg_this_step: + # Each rank computes its branch's velocity + local_v_t = self._forward_flow_single_branch( + x_t=x_t, + timestep=timestep, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_vae_position_ids=packed_vae_position_ids, + packed_text_ids=packed_text_ids, + packed_text_indexes=packed_text_indexes, + packed_indexes=branch_indexes, + packed_position_ids=branch_position_ids, + packed_seqlens=packed_seqlens, + key_values_lens=branch_key_values_lens, + past_key_values=branch_past_key_values, + packed_key_value_indexes=branch_key_value_indexes, + ) + + # All-gather velocities from all CFG ranks + gathered = cfg_group.all_gather(local_v_t, separate_tensors=True) + + # Rank 0 combines with CFG formula + if cfg_rank == 0: + v_t = gathered[0] # gen branch + cfg_text_v_t = gathered[1] # text_cfg branch + cfg_img_v_t = gathered[2] if (use_cfg_img and len(gathered) > 2) else None + v_t = self._combine_cfg( + v_t, + cfg_text_v_t, + cfg_img_v_t, + cfg_text_scale_, + cfg_img_scale_, + cfg_renorm_type, + cfg_renorm_min, + ) + x_t = x_t - v_t.to(x_t.device) * dts[i] + else: + # Outside cfg_interval: only rank 0 computes (no CFG needed) + if cfg_rank == 0: + v_t = self._forward_flow_single_branch( + x_t=x_t, + timestep=timestep, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_vae_position_ids=packed_vae_position_ids, + packed_text_ids=packed_text_ids, + packed_text_indexes=packed_text_indexes, + packed_indexes=packed_indexes, + packed_position_ids=packed_position_ids, + packed_seqlens=packed_seqlens, + key_values_lens=key_values_lens, + past_key_values=past_key_values, + packed_key_value_indexes=packed_key_value_indexes, + ) + x_t = x_t - v_t.to(x_t.device) * dts[i] + + # Broadcast updated x_t from rank 0 to all ranks + x_t = x_t.contiguous() + cfg_group.broadcast(x_t, src=0) + + unpacked_latent = x_t.split((packed_seqlens - 2).tolist()) + return unpacked_latent + + @staticmethod + def _combine_cfg( + v_t: torch.Tensor, + cfg_text_v_t: torch.Tensor, + cfg_img_v_t: torch.Tensor | None, + cfg_text_scale: float, + cfg_img_scale: float, + cfg_renorm_type: str, + cfg_renorm_min: float, + ) -> torch.Tensor: + """Combine 3-branch CFG predictions with renormalization. + + Args: + v_t: velocity from gen branch (full conditioning) + cfg_text_v_t: velocity from text_cfg branch (unconditional text) + cfg_img_v_t: velocity from img_cfg branch (no image), or None + cfg_text_scale: text guidance scale + cfg_img_scale: image guidance scale + cfg_renorm_type: "text_channel", "global", or "channel" + cfg_renorm_min: minimum renormalization scale + """ + if cfg_renorm_type == "text_channel": + v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t) + norm_v_t = torch.norm(v_t, dim=-1, keepdim=True) + norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True) + scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) + v_t_text = v_t_text_ * scale + if cfg_img_scale > 1.0 and cfg_img_v_t is not None: + v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t) + else: + v_t = v_t_text + else: + v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t) + + if cfg_img_scale > 1.0 and cfg_img_v_t is not None: + v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t) + else: + v_t_ = v_t_text_ + + # NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit + if cfg_renorm_type == "global": + norm_v_t = torch.norm(v_t) + norm_v_t_ = torch.norm(v_t_) + elif cfg_renorm_type == "channel": + norm_v_t = torch.norm(v_t, dim=-1, keepdim=True) + norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True) + else: + raise NotImplementedError(f"{cfg_renorm_type} is not supported") + scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) + v_t = v_t_ * scale + + return v_t + + def _forward_flow_single_branch( + self, + x_t: torch.Tensor, + timestep: torch.LongTensor, + packed_vae_token_indexes: torch.LongTensor, + packed_vae_position_ids: torch.LongTensor, + packed_text_ids: torch.LongTensor, + packed_text_indexes: torch.LongTensor, + packed_indexes: torch.LongTensor, + packed_position_ids: torch.LongTensor, + packed_seqlens: torch.IntTensor, + key_values_lens: torch.IntTensor, + past_key_values: NaiveCache, + packed_key_value_indexes: torch.LongTensor, + ) -> torch.Tensor: + """Run a single-branch forward pass (no CFG batching). + + Used by CFG parallel mode where each rank computes one branch. + Returns the velocity v_t for the given branch. + """ + packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) + packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size)) + packed_sequence[packed_text_indexes] = packed_text_embedding + + assert timestep.unique().shape[0] == 1 + packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids) + packed_timestep_embeds = self.time_embedder(timestep) + x_t_emb = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed + if x_t_emb.dtype != packed_sequence.dtype: + x_t_emb = x_t_emb.to(packed_sequence.dtype) + packed_sequence[packed_vae_token_indexes] = x_t_emb + + extra_inputs = {} + if self.use_moe: + extra_inputs["mode"] = "gen" + extra_inputs["packed_vae_token_indexes"] = packed_vae_token_indexes + extra_inputs["packed_text_indexes"] = packed_text_indexes + + output = self.language_model.forward( + packed_query_sequence=packed_sequence, + query_lens=packed_seqlens, + packed_query_position_ids=packed_position_ids, + packed_query_indexes=packed_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=False, + is_causal=False, + **extra_inputs, + ) + v_t = self.llm2vae(output.packed_query_sequence) + v_t = v_t[packed_vae_token_indexes] + return v_t + def _forward_flow( self, x_t: torch.Tensor, @@ -1535,34 +1835,14 @@ def _forward_flow( # ── CFG combination ── if use_cfg: - if cfg_renorm_type == "text_channel": - v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t) - norm_v_t = torch.norm(v_t, dim=-1, keepdim=True) - norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True) - scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) - v_t_text = v_t_text_ * scale - if cfg_img_scale > 1.0: - v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t) - else: - v_t = v_t_text - else: - v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t) - - if cfg_img_scale > 1.0: - v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t) - else: - v_t_ = v_t_text_ - - # NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit - if cfg_renorm_type == "global": - norm_v_t = torch.norm(v_t) - norm_v_t_ = torch.norm(v_t_) - elif cfg_renorm_type == "channel": - norm_v_t = torch.norm(v_t, dim=-1, keepdim=True) - norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True) - else: - raise NotImplementedError(f"{cfg_renorm_type} is not supported") - scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) - v_t = v_t_ * scale + v_t = self._combine_cfg( + v_t, + cfg_text_v_t, + cfg_img_v_t, + cfg_text_scale, + cfg_img_scale, + cfg_renorm_type, + cfg_renorm_min, + ) return v_t From 67dae8a6444f146353f551528fd4c38976e0f727 Mon Sep 17 00:00:00 2001 From: Ding Zuhao Date: Mon, 2 Mar 2026 14:47:00 +0800 Subject: [PATCH 2/4] test+fix: add _combine_cfg unit tests and fix dummy_run CFG defaults - Add 13 unit tests for Bagel._combine_cfg covering all renorm types, CFG scale combinations, and edge cases (CPU-only, no GPU required) - Fix dummy_run in diffusion_engine.py to set cfg_text_scale=1.0 and cfg_img_scale=1.0, preventing CFG parallel validation errors during warmup Signed-off-by: Ding Zuhao --- tests/diffusion/models/bagel/__init__.py | 0 .../models/bagel/test_combine_cfg.py | 314 ++++++++++++++++++ vllm_omni/diffusion/diffusion_engine.py | 3 + 3 files changed, 317 insertions(+) create mode 100644 tests/diffusion/models/bagel/__init__.py create mode 100644 tests/diffusion/models/bagel/test_combine_cfg.py diff --git a/tests/diffusion/models/bagel/__init__.py b/tests/diffusion/models/bagel/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/diffusion/models/bagel/test_combine_cfg.py b/tests/diffusion/models/bagel/test_combine_cfg.py new file mode 100644 index 00000000000..48be54626b3 --- /dev/null +++ b/tests/diffusion/models/bagel/test_combine_cfg.py @@ -0,0 +1,314 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for Bagel._combine_cfg logic.""" + +import pytest +import torch + +from vllm_omni.diffusion.models.bagel.bagel_transformer import Bagel + + +class TestCombineCfg: + """Tests for the _combine_cfg static method.""" + + def _make_tensors(self, shape=(10, 64), seed=42): + """Create deterministic test tensors.""" + gen = torch.Generator().manual_seed(seed) + v_t = torch.randn(shape, generator=gen) + cfg_text_v_t = torch.randn(shape, generator=gen) + cfg_img_v_t = torch.randn(shape, generator=gen) + return v_t, cfg_text_v_t, cfg_img_v_t + + def test_text_channel_renorm_preserves_direction(self): + """text_channel renorm should change direction but constrain magnitude.""" + v_t, cfg_text_v_t, _ = self._make_tensors() + + result = Bagel._combine_cfg( + v_t, + cfg_text_v_t, + None, + cfg_text_scale=4.0, + cfg_img_scale=1.0, + cfg_renorm_type="text_channel", + cfg_renorm_min=0.0, + ) + + # Result norm per token should be <= original v_t norm (clamp max=1.0) + result_norm = torch.norm(result, dim=-1) + v_t_norm = torch.norm(v_t, dim=-1) + assert torch.all(result_norm <= v_t_norm + 1e-6), "text_channel renorm should not increase per-token norm" + + def test_scale_1_returns_v_t(self): + """cfg_text_scale=1.0 means no CFG: result should equal v_t.""" + v_t, cfg_text_v_t, _ = self._make_tensors() + + result = Bagel._combine_cfg( + v_t, + cfg_text_v_t, + None, + cfg_text_scale=1.0, + cfg_img_scale=1.0, + cfg_renorm_type="text_channel", + cfg_renorm_min=0.0, + ) + + # scale=1 → v_t_text_ = cfg_text + 1*(v_t - cfg_text) = v_t + # renorm scale = norm(v_t)/norm(v_t) = 1.0, so result = v_t + assert torch.allclose(result, v_t, atol=1e-6) + + def test_img_cfg_applied_when_scale_gt_1(self): + """When cfg_img_scale > 1.0, result should differ from text-only CFG.""" + v_t, cfg_text_v_t, cfg_img_v_t = self._make_tensors() + + text_only = Bagel._combine_cfg( + v_t, + cfg_text_v_t, + None, + cfg_text_scale=4.0, + cfg_img_scale=1.0, + cfg_renorm_type="text_channel", + cfg_renorm_min=0.0, + ) + + with_img = Bagel._combine_cfg( + v_t, + cfg_text_v_t, + cfg_img_v_t, + cfg_text_scale=4.0, + cfg_img_scale=1.5, + cfg_renorm_type="text_channel", + cfg_renorm_min=0.0, + ) + + assert not torch.allclose(text_only, with_img, atol=1e-6), ( + "Image CFG should produce different result from text-only CFG" + ) + + def test_img_cfg_none_ignored(self): + """cfg_img_v_t=None should be equivalent to cfg_img_scale <= 1.0.""" + v_t, cfg_text_v_t, cfg_img_v_t = self._make_tensors() + + result_none = Bagel._combine_cfg( + v_t, + cfg_text_v_t, + None, + cfg_text_scale=4.0, + cfg_img_scale=1.5, + cfg_renorm_type="text_channel", + cfg_renorm_min=0.0, + ) + + result_low_scale = Bagel._combine_cfg( + v_t, + cfg_text_v_t, + cfg_img_v_t, + cfg_text_scale=4.0, + cfg_img_scale=0.5, + cfg_renorm_type="text_channel", + cfg_renorm_min=0.0, + ) + + assert torch.allclose(result_none, result_low_scale, atol=1e-6), ( + "cfg_img_v_t=None and cfg_img_scale<=1.0 should give same result" + ) + + def test_global_renorm(self): + """global renorm should produce valid output without error.""" + v_t, cfg_text_v_t, cfg_img_v_t = self._make_tensors() + + result = Bagel._combine_cfg( + v_t, + cfg_text_v_t, + cfg_img_v_t, + cfg_text_scale=4.0, + cfg_img_scale=1.5, + cfg_renorm_type="global", + cfg_renorm_min=0.0, + ) + + assert result.shape == v_t.shape + assert not torch.any(torch.isnan(result)) + + def test_channel_renorm(self): + """channel renorm should produce valid output without error.""" + v_t, cfg_text_v_t, cfg_img_v_t = self._make_tensors() + + result = Bagel._combine_cfg( + v_t, + cfg_text_v_t, + cfg_img_v_t, + cfg_text_scale=4.0, + cfg_img_scale=1.5, + cfg_renorm_type="channel", + cfg_renorm_min=0.0, + ) + + assert result.shape == v_t.shape + assert not torch.any(torch.isnan(result)) + + def test_invalid_renorm_type_raises(self): + """Unknown renorm type should raise NotImplementedError.""" + v_t, cfg_text_v_t, _ = self._make_tensors() + + with pytest.raises(NotImplementedError): + Bagel._combine_cfg( + v_t, + cfg_text_v_t, + None, + cfg_text_scale=4.0, + cfg_img_scale=1.0, + cfg_renorm_type="unknown", + cfg_renorm_min=0.0, + ) + + def test_renorm_min_clamps_scale(self): + """cfg_renorm_min should prevent scale from going too low.""" + v_t = torch.ones(10, 64) + # Make cfg_text_v_t very different so CFG amplifies heavily + cfg_text_v_t = torch.zeros(10, 64) + + result_no_min = Bagel._combine_cfg( + v_t, + cfg_text_v_t, + None, + cfg_text_scale=100.0, + cfg_img_scale=1.0, + cfg_renorm_type="text_channel", + cfg_renorm_min=0.0, + ) + + result_with_min = Bagel._combine_cfg( + v_t, + cfg_text_v_t, + None, + cfg_text_scale=100.0, + cfg_img_scale=1.0, + cfg_renorm_type="text_channel", + cfg_renorm_min=0.5, + ) + + # With higher renorm_min, result magnitude should be larger + # (scale is clamped to at least 0.5 instead of going near 0) + norm_no_min = torch.norm(result_no_min) + norm_with_min = torch.norm(result_with_min) + assert norm_with_min >= norm_no_min - 1e-6, "Higher cfg_renorm_min should preserve more magnitude" + + def test_global_renorm_with_img_cfg(self): + """global renorm + img CFG should produce valid, different output.""" + v_t, cfg_text_v_t, cfg_img_v_t = self._make_tensors() + + text_only = Bagel._combine_cfg( + v_t.clone(), + cfg_text_v_t.clone(), + None, + cfg_text_scale=4.0, + cfg_img_scale=1.0, + cfg_renorm_type="global", + cfg_renorm_min=0.0, + ) + + with_img = Bagel._combine_cfg( + v_t.clone(), + cfg_text_v_t.clone(), + cfg_img_v_t.clone(), + cfg_text_scale=4.0, + cfg_img_scale=1.5, + cfg_renorm_type="global", + cfg_renorm_min=0.0, + ) + + assert not torch.allclose(text_only, with_img, atol=1e-6), ( + "global renorm + img CFG should differ from text-only" + ) + assert not torch.any(torch.isnan(with_img)) + + def test_channel_renorm_with_img_cfg(self): + """channel renorm + img CFG should produce valid, different output.""" + v_t, cfg_text_v_t, cfg_img_v_t = self._make_tensors() + + text_only = Bagel._combine_cfg( + v_t.clone(), + cfg_text_v_t.clone(), + None, + cfg_text_scale=4.0, + cfg_img_scale=1.0, + cfg_renorm_type="channel", + cfg_renorm_min=0.0, + ) + + with_img = Bagel._combine_cfg( + v_t.clone(), + cfg_text_v_t.clone(), + cfg_img_v_t.clone(), + cfg_text_scale=4.0, + cfg_img_scale=1.5, + cfg_renorm_type="channel", + cfg_renorm_min=0.0, + ) + + assert not torch.allclose(text_only, with_img, atol=1e-6), ( + "channel renorm + img CFG should differ from text-only" + ) + assert not torch.any(torch.isnan(with_img)) + + def test_global_channel_renorm_constrains_norm(self): + """global and channel renorm should not increase overall norm.""" + v_t, cfg_text_v_t, cfg_img_v_t = self._make_tensors() + + for renorm_type in ["global", "channel"]: + result = Bagel._combine_cfg( + v_t.clone(), + cfg_text_v_t.clone(), + cfg_img_v_t.clone(), + cfg_text_scale=4.0, + cfg_img_scale=1.5, + cfg_renorm_type=renorm_type, + cfg_renorm_min=0.0, + ) + # Global norm of result should be <= global norm of v_t (clamp max=1.0) + assert torch.norm(result) <= torch.norm(v_t) + 1e-5, f"{renorm_type} renorm should not increase global norm" + + def test_text_channel_img_cfg_no_second_renorm(self): + """text_channel mode: img CFG is applied AFTER renorm, without a second renorm. + So the result norm can exceed v_t norm when img_scale > 1.""" + v_t, cfg_text_v_t, cfg_img_v_t = self._make_tensors() + + result = Bagel._combine_cfg( + v_t, + cfg_text_v_t, + cfg_img_v_t, + cfg_text_scale=4.0, + cfg_img_scale=2.0, + cfg_renorm_type="text_channel", + cfg_renorm_min=0.0, + ) + + # text_channel renorms after text CFG, then applies img CFG without renorm + # So result norm CAN exceed v_t norm — this is expected behavior + assert result.shape == v_t.shape + assert not torch.any(torch.isnan(result)) + + def test_all_renorm_types_consistent_direction(self): + """All renorm types should guide in the same general direction.""" + v_t, cfg_text_v_t, _ = self._make_tensors() + + results = {} + for renorm_type in ["text_channel", "global", "channel"]: + results[renorm_type] = Bagel._combine_cfg( + v_t.clone(), + cfg_text_v_t.clone(), + None, + cfg_text_scale=4.0, + cfg_img_scale=1.0, + cfg_renorm_type=renorm_type, + cfg_renorm_min=0.0, + ) + + # All results should have positive cosine similarity with each other + for a_name, a in results.items(): + for b_name, b in results.items(): + cos_sim = torch.nn.functional.cosine_similarity(a.flatten().unsqueeze(0), b.flatten().unsqueeze(0)) + assert cos_sim > 0.5, ( + f"{a_name} and {b_name} should point in similar direction, " + f"but cosine similarity = {cos_sim.item():.4f}" + ) diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 8e4a9f7a20f..04c68c9ce86 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -336,6 +336,9 @@ def _dummy_run(self): # classifier-free guidance with an empty negative prompt. guidance_scale=0.0, num_outputs_per_prompt=1, + # Disable CFG for warmup to avoid triggering CFG parallel + # validation when cfg_parallel_size > 1. + extra_args={"cfg_text_scale": 1.0, "cfg_img_scale": 1.0}, ), ) logger.info("dummy run to warm up the model") From f6556f2ad3bfc5b071bea36bbbc0fa3db23a9a59 Mon Sep 17 00:00:00 2001 From: Ding Zuhao Date: Mon, 2 Mar 2026 14:47:07 +0800 Subject: [PATCH 3/4] examples: update end2end.py with CFG parallel and negative prompt support - Add --cfg-parallel-size flag for selecting batched vs parallel mode - Add --negative-prompt support for text2img CFG - Always pass parallel_config to OmniDiffusion (cfg_parallel_size=1 default) Signed-off-by: Ding Zuhao --- examples/offline_inference/bagel/end2end.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/examples/offline_inference/bagel/end2end.py b/examples/offline_inference/bagel/end2end.py index 04ae4c15d6d..584e463e154 100644 --- a/examples/offline_inference/bagel/end2end.py +++ b/examples/offline_inference/bagel/end2end.py @@ -51,6 +51,13 @@ def parse_args(): parser.add_argument( "--negative-prompt", type=str, default=None, help="Negative prompt for CFG (default: empty prompt)" ) + parser.add_argument( + "--cfg-parallel-size", + type=int, + default=1, + choices=[1, 2, 3], + help="CFG parallel size: 1=batched (single GPU), 2=parallel with 2 branches (text CFG only), 3=parallel (3 GPUs).", + ) args = parser.parse_args() return args @@ -82,12 +89,14 @@ def main(): from PIL import Image if args.modality == "img2img": - from PIL import Image - from vllm_omni.entrypoints.omni_diffusion import OmniDiffusion - print("[Info] Running in img2img mode (Stage 1 only)") - client = OmniDiffusion(model=model_name) + print(f"[Info] Running in {args.modality} mode (Stage 1 only, cfg_parallel_size={args.cfg_parallel_size})") + + client = OmniDiffusion( + model=model_name, + parallel_config={"cfg_parallel_size": args.cfg_parallel_size}, + ) if args.image_path: if os.path.exists(args.image_path): From b808e55900f52466944b027ad3090ccc740fc635 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Thu, 5 Mar 2026 23:00:06 +0800 Subject: [PATCH 4/4] Add pytest markers for core model and CPU tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 汪志鹏 --- tests/diffusion/models/bagel/test_combine_cfg.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/diffusion/models/bagel/test_combine_cfg.py b/tests/diffusion/models/bagel/test_combine_cfg.py index 48be54626b3..88611fdc454 100644 --- a/tests/diffusion/models/bagel/test_combine_cfg.py +++ b/tests/diffusion/models/bagel/test_combine_cfg.py @@ -7,6 +7,8 @@ from vllm_omni.diffusion.models.bagel.bagel_transformer import Bagel +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + class TestCombineCfg: """Tests for the _combine_cfg static method."""