Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions tests/e2e/offline_inference/test_bagel_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,30 @@
# prompt='Change the grass color to red',
# input image: 2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg
REFERENCE_PIXELS = [
{"position": (100, 100), "rgb": (157, 172, 217)},
{"position": (400, 50), "rgb": (105, 144, 218)},
{"position": (700, 100), "rgb": (118, 159, 233)},
{"position": (150, 400), "rgb": (195, 34, 60)},
{"position": (512, 336), "rgb": (222, 214, 193)},
{"position": (700, 400), "rgb": (197, 15, 43)},
{"position": (100, 600), "rgb": (105, 13, 18)},
{"position": (400, 600), "rgb": (169, 33, 44)},
{"position": (700, 600), "rgb": (101, 86, 93)},
{"position": (256, 256), "rgb": (181, 202, 222)},
{"position": (100, 100), "rgb": (156, 172, 217)},
{"position": (400, 50), "rgb": (105, 144, 217)},
{"position": (700, 100), "rgb": (118, 159, 232)},
{"position": (150, 400), "rgb": (180, 22, 52)},
{"position": (512, 336), "rgb": (221, 211, 194)},
{"position": (700, 400), "rgb": (192, 10, 46)},
{"position": (100, 600), "rgb": (102, 12, 22)},
{"position": (400, 600), "rgb": (161, 28, 47)},
{"position": (700, 600), "rgb": (100, 87, 94)},
{"position": (256, 256), "rgb": (181, 201, 221)},
]

if current_omni_platform.is_rocm():
REFERENCE_PIXELS = [
{"position": (100, 100), "rgb": (156, 172, 215)},
{"position": (400, 50), "rgb": (106, 144, 216)},
{"position": (700, 100), "rgb": (118, 158, 231)},
{"position": (150, 400), "rgb": (183, 23, 48)},
{"position": (512, 336), "rgb": (218, 215, 191)},
{"position": (700, 400), "rgb": (194, 14, 42)},
{"position": (100, 600), "rgb": (105, 10, 16)},
{"position": (400, 600), "rgb": (167, 33, 46)},
{"position": (700, 600), "rgb": (102, 86, 92)},
{"position": (256, 256), "rgb": (181, 201, 220)},
{"position": (100, 100), "rgb": (156, 172, 217)},
{"position": (400, 50), "rgb": (105, 144, 217)},
{"position": (700, 100), "rgb": (118, 159, 232)},
{"position": (150, 400), "rgb": (180, 22, 52)},
{"position": (512, 336), "rgb": (221, 211, 194)},
{"position": (700, 400), "rgb": (192, 10, 46)},
{"position": (100, 600), "rgb": (102, 12, 22)},
{"position": (400, 600), "rgb": (161, 28, 47)},
{"position": (700, 600), "rgb": (100, 87, 94)},
{"position": (256, 256), "rgb": (181, 201, 221)},
]

PIXEL_TOLERANCE = 10
Expand Down
40 changes: 20 additions & 20 deletions tests/e2e/offline_inference/test_bagel_text2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,30 +37,30 @@
# "Generated with seed=52, num_inference_steps=15,
# prompt='A futuristic city skyline at twilight, cyberpunk style'"
REFERENCE_PIXELS = [
{"position": (100, 100), "rgb": (121, 118, 100)},
{"position": (400, 50), "rgb": (163, 162, 143)},
{"position": (700, 100), "rgb": (170, 156, 127)},
{"position": (150, 400), "rgb": (129, 127, 112)},
{"position": (512, 512), "rgb": (135, 61, 59)},
{"position": (700, 400), "rgb": (205, 107, 43)},
{"position": (100, 700), "rgb": (197, 177, 157)},
{"position": (400, 700), "rgb": (139, 107, 86)},
{"position": (700, 700), "rgb": (247, 205, 146)},
{"position": (256, 256), "rgb": (171, 160, 153)},
{"position": (100, 100), "rgb": (115, 113, 94)},
{"position": (400, 50), "rgb": (159, 160, 144)},
{"position": (700, 100), "rgb": (164, 151, 123)},
{"position": (150, 400), "rgb": (120, 121, 107)},
{"position": (512, 512), "rgb": (165, 133, 127)},
{"position": (700, 400), "rgb": (217, 130, 66)},
{"position": (100, 700), "rgb": (191, 168, 152)},
{"position": (400, 700), "rgb": (130, 96, 77)},
{"position": (700, 700), "rgb": (247, 203, 140)},
{"position": (256, 256), "rgb": (167, 156, 150)},
]

if current_omni_platform.is_rocm():
REFERENCE_PIXELS = [
{"position": (100, 100), "rgb": (123, 119, 100)},
{"position": (400, 50), "rgb": (162, 161, 142)},
{"position": (700, 100), "rgb": (171, 156, 127)},
{"position": (150, 400), "rgb": (131, 128, 112)},
{"position": (512, 512), "rgb": (134, 61, 59)},
{"position": (700, 400), "rgb": (204, 107, 43)},
{"position": (100, 700), "rgb": (201, 180, 165)},
{"position": (400, 700), "rgb": (140, 108, 87)},
{"position": (700, 700), "rgb": (247, 205, 145)},
{"position": (256, 256), "rgb": (171, 160, 153)},
{"position": (100, 100), "rgb": (115, 113, 94)},
{"position": (400, 50), "rgb": (159, 160, 144)},
{"position": (700, 100), "rgb": (164, 151, 123)},
{"position": (150, 400), "rgb": (120, 121, 107)},
{"position": (512, 512), "rgb": (165, 133, 127)},
{"position": (700, 400), "rgb": (217, 130, 66)},
{"position": (100, 700), "rgb": (191, 168, 152)},
{"position": (400, 700), "rgb": (130, 96, 77)},
{"position": (700, 700), "rgb": (247, 203, 140)},
{"position": (256, 256), "rgb": (167, 156, 150)},
]

# Maximum allowed difference per color channel
Expand Down
34 changes: 20 additions & 14 deletions vllm_omni/diffusion/models/bagel/pipeline_bagel.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,11 +397,26 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
cfg_text_context["ropes"] = cfg_text_metadata["ropes"]
else:
cfg_text_context["ropes"] = [cfg_text_seq_len]

if cfg_img_kv is None and cfg_text_kv is not None:
cfg_img_kv = injected_kv

if cfg_img_kv is not None:
else:
# No cfg_text companion received. For text2img this is the
# expected path: original BAGEL uses an empty KV cache (0
# tokens) as the text-unconditional branch. Keep the default
# empty NaiveCache in cfg_text_context and preserve the
# original cfg_text_scale so CFG still applies.
pass

if cfg_img_kv is None:
# text2img multi-stage: cfg_img reuses gen KV (positive prompt,
# no image), mirroring forward_cache_update_text on cfg_img_context
# in the single-stage path.
cfg_img_seq_len = injected_kv.key_cache[0].shape[0]
cfg_img_context["past_key_values"] = injected_kv
cfg_img_context["kv_lens"] = [cfg_img_seq_len]
if req.sampling_params.kv_metadata and "ropes" in req.sampling_params.kv_metadata:
cfg_img_context["ropes"] = req.sampling_params.kv_metadata["ropes"]
else:
cfg_img_context["ropes"] = [cfg_img_seq_len]
else:
cfg_img_seq_len = cfg_img_kv.key_cache[0].shape[0]
cfg_img_context["past_key_values"] = cfg_img_kv
cfg_img_context["kv_lens"] = [cfg_img_seq_len]
Expand All @@ -410,15 +425,6 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
else:
cfg_img_context["ropes"] = [cfg_img_seq_len]

if not cfg_parallel_contract:
logger.warning("CFG is disabled: only single KV cache available")
gen_params = BagelGenParams(
num_timesteps=gen_params.num_timesteps,
timestep_shift=gen_params.timestep_shift,
cfg_text_scale=1.0,
cfg_img_scale=1.0,
)

else:
image_input = (
None
Expand Down
13 changes: 9 additions & 4 deletions vllm_omni/model_executor/stage_input_processors/bagel.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def expand_cfg_prompts(
neg_prompt = _get_negative_prompt(prompt, sampling_params)

if "image" in modalities:
if not neg_prompt:
return []
neg_prompt_dict = {
"prompt": neg_prompt,
"modalities": prompt.get("modalities", []),
Expand Down Expand Up @@ -166,6 +168,8 @@ def expand_cfg_prompts_think(
companion_params = {"max_tokens": 1}

if "image" in modalities:
if not neg_prompt:
return []
neg_prompt_dict = {
"prompt": neg_prompt,
"modalities": prompt.get("modalities", []),
Expand Down Expand Up @@ -287,9 +291,10 @@ def _get_negative_prompt(
) -> str:
"""Resolve the negative prompt for CFG from prompt or sampling params.

An empty string is treated the same as absent (falls through to
the Bagel default token pair), because an empty negative prompt is
not meaningful for CFG guidance.
Returns the negative prompt string when one is supplied, otherwise an
empty string. Callers decide how to treat the empty case: text2img
skips the cfg_text companion entirely, while img2img substitutes it
into the cfg_text prompt template.
"""
neg = prompt.get("negative_prompt")
if neg:
Expand All @@ -300,4 +305,4 @@ def _get_negative_prompt(
if neg:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring above (lines 293-296) is stale — it still says "falls through to the Bagel default token pair", but we now return "". Please update.

return neg

return "<|im_start|><|im_end|>"
return ""
Loading