diff --git a/tests/e2e/offline_inference/test_bagel_img2img.py b/tests/e2e/offline_inference/test_bagel_img2img.py index 63d2a37da79..be79aa7348a 100644 --- a/tests/e2e/offline_inference/test_bagel_img2img.py +++ b/tests/e2e/offline_inference/test_bagel_img2img.py @@ -32,30 +32,30 @@ # prompt='Change the grass color to red', # input image: 2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg REFERENCE_PIXELS = [ - {"position": (100, 100), "rgb": (157, 172, 217)}, - {"position": (400, 50), "rgb": (105, 144, 218)}, - {"position": (700, 100), "rgb": (118, 159, 233)}, - {"position": (150, 400), "rgb": (195, 34, 60)}, - {"position": (512, 336), "rgb": (222, 214, 193)}, - {"position": (700, 400), "rgb": (197, 15, 43)}, - {"position": (100, 600), "rgb": (105, 13, 18)}, - {"position": (400, 600), "rgb": (169, 33, 44)}, - {"position": (700, 600), "rgb": (101, 86, 93)}, - {"position": (256, 256), "rgb": (181, 202, 222)}, + {"position": (100, 100), "rgb": (156, 172, 217)}, + {"position": (400, 50), "rgb": (105, 144, 217)}, + {"position": (700, 100), "rgb": (118, 159, 232)}, + {"position": (150, 400), "rgb": (180, 22, 52)}, + {"position": (512, 336), "rgb": (221, 211, 194)}, + {"position": (700, 400), "rgb": (192, 10, 46)}, + {"position": (100, 600), "rgb": (102, 12, 22)}, + {"position": (400, 600), "rgb": (161, 28, 47)}, + {"position": (700, 600), "rgb": (100, 87, 94)}, + {"position": (256, 256), "rgb": (181, 201, 221)}, ] if current_omni_platform.is_rocm(): REFERENCE_PIXELS = [ - {"position": (100, 100), "rgb": (156, 172, 215)}, - {"position": (400, 50), "rgb": (106, 144, 216)}, - {"position": (700, 100), "rgb": (118, 158, 231)}, - {"position": (150, 400), "rgb": (183, 23, 48)}, - {"position": (512, 336), "rgb": (218, 215, 191)}, - {"position": (700, 400), "rgb": (194, 14, 42)}, - {"position": (100, 600), "rgb": (105, 10, 16)}, - {"position": (400, 600), "rgb": (167, 33, 46)}, - {"position": (700, 600), "rgb": (102, 86, 92)}, - {"position": (256, 256), "rgb": (181, 201, 220)}, + {"position": (100, 100), "rgb": (156, 172, 217)}, + {"position": (400, 50), "rgb": (105, 144, 217)}, + {"position": (700, 100), "rgb": (118, 159, 232)}, + {"position": (150, 400), "rgb": (180, 22, 52)}, + {"position": (512, 336), "rgb": (221, 211, 194)}, + {"position": (700, 400), "rgb": (192, 10, 46)}, + {"position": (100, 600), "rgb": (102, 12, 22)}, + {"position": (400, 600), "rgb": (161, 28, 47)}, + {"position": (700, 600), "rgb": (100, 87, 94)}, + {"position": (256, 256), "rgb": (181, 201, 221)}, ] PIXEL_TOLERANCE = 10 diff --git a/tests/e2e/offline_inference/test_bagel_text2img.py b/tests/e2e/offline_inference/test_bagel_text2img.py index e45d64f2ac5..534b8730682 100644 --- a/tests/e2e/offline_inference/test_bagel_text2img.py +++ b/tests/e2e/offline_inference/test_bagel_text2img.py @@ -37,30 +37,30 @@ # "Generated with seed=52, num_inference_steps=15, # prompt='A futuristic city skyline at twilight, cyberpunk style'" REFERENCE_PIXELS = [ - {"position": (100, 100), "rgb": (121, 118, 100)}, - {"position": (400, 50), "rgb": (163, 162, 143)}, - {"position": (700, 100), "rgb": (170, 156, 127)}, - {"position": (150, 400), "rgb": (129, 127, 112)}, - {"position": (512, 512), "rgb": (135, 61, 59)}, - {"position": (700, 400), "rgb": (205, 107, 43)}, - {"position": (100, 700), "rgb": (197, 177, 157)}, - {"position": (400, 700), "rgb": (139, 107, 86)}, - {"position": (700, 700), "rgb": (247, 205, 146)}, - {"position": (256, 256), "rgb": (171, 160, 153)}, + {"position": (100, 100), "rgb": (115, 113, 94)}, + {"position": (400, 50), "rgb": (159, 160, 144)}, + {"position": (700, 100), "rgb": (164, 151, 123)}, + {"position": (150, 400), "rgb": (120, 121, 107)}, + {"position": (512, 512), "rgb": (165, 133, 127)}, + {"position": (700, 400), "rgb": (217, 130, 66)}, + {"position": (100, 700), "rgb": (191, 168, 152)}, + {"position": (400, 700), "rgb": (130, 96, 77)}, + {"position": (700, 700), "rgb": (247, 203, 140)}, + {"position": (256, 256), "rgb": (167, 156, 150)}, ] if current_omni_platform.is_rocm(): REFERENCE_PIXELS = [ - {"position": (100, 100), "rgb": (123, 119, 100)}, - {"position": (400, 50), "rgb": (162, 161, 142)}, - {"position": (700, 100), "rgb": (171, 156, 127)}, - {"position": (150, 400), "rgb": (131, 128, 112)}, - {"position": (512, 512), "rgb": (134, 61, 59)}, - {"position": (700, 400), "rgb": (204, 107, 43)}, - {"position": (100, 700), "rgb": (201, 180, 165)}, - {"position": (400, 700), "rgb": (140, 108, 87)}, - {"position": (700, 700), "rgb": (247, 205, 145)}, - {"position": (256, 256), "rgb": (171, 160, 153)}, + {"position": (100, 100), "rgb": (115, 113, 94)}, + {"position": (400, 50), "rgb": (159, 160, 144)}, + {"position": (700, 100), "rgb": (164, 151, 123)}, + {"position": (150, 400), "rgb": (120, 121, 107)}, + {"position": (512, 512), "rgb": (165, 133, 127)}, + {"position": (700, 400), "rgb": (217, 130, 66)}, + {"position": (100, 700), "rgb": (191, 168, 152)}, + {"position": (400, 700), "rgb": (130, 96, 77)}, + {"position": (700, 700), "rgb": (247, 203, 140)}, + {"position": (256, 256), "rgb": (167, 156, 150)}, ] # Maximum allowed difference per color channel diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index a3d2259e643..90baf5f6761 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -397,11 +397,26 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: cfg_text_context["ropes"] = cfg_text_metadata["ropes"] else: cfg_text_context["ropes"] = [cfg_text_seq_len] - - if cfg_img_kv is None and cfg_text_kv is not None: - cfg_img_kv = injected_kv - - if cfg_img_kv is not None: + else: + # No cfg_text companion received. For text2img this is the + # expected path: original BAGEL uses an empty KV cache (0 + # tokens) as the text-unconditional branch. Keep the default + # empty NaiveCache in cfg_text_context and preserve the + # original cfg_text_scale so CFG still applies. + pass + + if cfg_img_kv is None: + # text2img multi-stage: cfg_img reuses gen KV (positive prompt, + # no image), mirroring forward_cache_update_text on cfg_img_context + # in the single-stage path. + cfg_img_seq_len = injected_kv.key_cache[0].shape[0] + cfg_img_context["past_key_values"] = injected_kv + cfg_img_context["kv_lens"] = [cfg_img_seq_len] + if req.sampling_params.kv_metadata and "ropes" in req.sampling_params.kv_metadata: + cfg_img_context["ropes"] = req.sampling_params.kv_metadata["ropes"] + else: + cfg_img_context["ropes"] = [cfg_img_seq_len] + else: cfg_img_seq_len = cfg_img_kv.key_cache[0].shape[0] cfg_img_context["past_key_values"] = cfg_img_kv cfg_img_context["kv_lens"] = [cfg_img_seq_len] @@ -410,15 +425,6 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: else: cfg_img_context["ropes"] = [cfg_img_seq_len] - if not cfg_parallel_contract: - logger.warning("CFG is disabled: only single KV cache available") - gen_params = BagelGenParams( - num_timesteps=gen_params.num_timesteps, - timestep_shift=gen_params.timestep_shift, - cfg_text_scale=1.0, - cfg_img_scale=1.0, - ) - else: image_input = ( None diff --git a/vllm_omni/model_executor/stage_input_processors/bagel.py b/vllm_omni/model_executor/stage_input_processors/bagel.py index bfcff0ea0f3..52cc14d3aa2 100644 --- a/vllm_omni/model_executor/stage_input_processors/bagel.py +++ b/vllm_omni/model_executor/stage_input_processors/bagel.py @@ -82,6 +82,8 @@ def expand_cfg_prompts( neg_prompt = _get_negative_prompt(prompt, sampling_params) if "image" in modalities: + if not neg_prompt: + return [] neg_prompt_dict = { "prompt": neg_prompt, "modalities": prompt.get("modalities", []), @@ -166,6 +168,8 @@ def expand_cfg_prompts_think( companion_params = {"max_tokens": 1} if "image" in modalities: + if not neg_prompt: + return [] neg_prompt_dict = { "prompt": neg_prompt, "modalities": prompt.get("modalities", []), @@ -287,9 +291,10 @@ def _get_negative_prompt( ) -> str: """Resolve the negative prompt for CFG from prompt or sampling params. - An empty string is treated the same as absent (falls through to - the Bagel default token pair), because an empty negative prompt is - not meaningful for CFG guidance. + Returns the negative prompt string when one is supplied, otherwise an + empty string. Callers decide how to treat the empty case: text2img + skips the cfg_text companion entirely, while img2img substitutes it + into the cfg_text prompt template. """ neg = prompt.get("negative_prompt") if neg: @@ -300,4 +305,4 @@ def _get_negative_prompt( if neg: return neg - return "<|im_start|><|im_end|>" + return ""