Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions tests/e2e/offline_inference/test_bagel_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,30 @@
# prompt='Change the grass color to red',
# input image: 2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg
REFERENCE_PIXELS = [
{"position": (100, 100), "rgb": (157, 172, 217)},
{"position": (400, 50), "rgb": (105, 144, 218)},
{"position": (700, 100), "rgb": (118, 159, 233)},
{"position": (150, 400), "rgb": (195, 34, 60)},
{"position": (512, 336), "rgb": (222, 214, 193)},
{"position": (700, 400), "rgb": (197, 15, 43)},
{"position": (100, 600), "rgb": (105, 13, 18)},
{"position": (400, 600), "rgb": (169, 33, 44)},
{"position": (700, 600), "rgb": (101, 86, 93)},
{"position": (256, 256), "rgb": (181, 202, 222)},
{"position": (100, 100), "rgb": (156, 172, 217)},
{"position": (400, 50), "rgb": (105, 144, 217)},
{"position": (700, 100), "rgb": (118, 159, 232)},
{"position": (150, 400), "rgb": (180, 22, 52)},
{"position": (512, 336), "rgb": (221, 211, 194)},
{"position": (700, 400), "rgb": (192, 10, 46)},
{"position": (100, 600), "rgb": (102, 12, 22)},
{"position": (400, 600), "rgb": (161, 28, 47)},
{"position": (700, 600), "rgb": (100, 87, 94)},
{"position": (256, 256), "rgb": (181, 201, 221)},
]

if current_omni_platform.is_rocm():
REFERENCE_PIXELS = [
{"position": (100, 100), "rgb": (156, 172, 215)},
{"position": (400, 50), "rgb": (106, 144, 216)},
{"position": (700, 100), "rgb": (118, 158, 231)},
{"position": (150, 400), "rgb": (183, 23, 48)},
{"position": (512, 336), "rgb": (218, 215, 191)},
{"position": (700, 400), "rgb": (194, 14, 42)},
{"position": (100, 600), "rgb": (105, 10, 16)},
{"position": (400, 600), "rgb": (167, 33, 46)},
{"position": (700, 600), "rgb": (102, 86, 92)},
{"position": (256, 256), "rgb": (181, 201, 220)},
{"position": (100, 100), "rgb": (156, 172, 217)},
{"position": (400, 50), "rgb": (105, 144, 217)},
{"position": (700, 100), "rgb": (118, 159, 232)},
{"position": (150, 400), "rgb": (180, 22, 52)},
{"position": (512, 336), "rgb": (221, 211, 194)},
{"position": (700, 400), "rgb": (192, 10, 46)},
{"position": (100, 600), "rgb": (102, 12, 22)},
{"position": (400, 600), "rgb": (161, 28, 47)},
{"position": (700, 600), "rgb": (100, 87, 94)},
{"position": (256, 256), "rgb": (181, 201, 221)},
]

PIXEL_TOLERANCE = 10
Expand Down
40 changes: 20 additions & 20 deletions tests/e2e/offline_inference/test_bagel_text2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,30 +37,30 @@
# "Generated with seed=52, num_inference_steps=15,
# prompt='A futuristic city skyline at twilight, cyberpunk style'"
REFERENCE_PIXELS = [
{"position": (100, 100), "rgb": (121, 118, 100)},
{"position": (400, 50), "rgb": (163, 162, 143)},
{"position": (700, 100), "rgb": (170, 156, 127)},
{"position": (150, 400), "rgb": (129, 127, 112)},
{"position": (512, 512), "rgb": (135, 61, 59)},
{"position": (700, 400), "rgb": (205, 107, 43)},
{"position": (100, 700), "rgb": (197, 177, 157)},
{"position": (400, 700), "rgb": (139, 107, 86)},
{"position": (700, 700), "rgb": (247, 205, 146)},
{"position": (256, 256), "rgb": (171, 160, 153)},
{"position": (100, 100), "rgb": (115, 113, 94)},
{"position": (400, 50), "rgb": (159, 160, 144)},
{"position": (700, 100), "rgb": (164, 151, 123)},
{"position": (150, 400), "rgb": (120, 121, 107)},
{"position": (512, 512), "rgb": (165, 133, 127)},
{"position": (700, 400), "rgb": (217, 130, 66)},
{"position": (100, 700), "rgb": (191, 168, 152)},
{"position": (400, 700), "rgb": (130, 96, 77)},
{"position": (700, 700), "rgb": (247, 203, 140)},
{"position": (256, 256), "rgb": (167, 156, 150)},
]

if current_omni_platform.is_rocm():
REFERENCE_PIXELS = [
{"position": (100, 100), "rgb": (123, 119, 100)},
{"position": (400, 50), "rgb": (162, 161, 142)},
{"position": (700, 100), "rgb": (171, 156, 127)},
{"position": (150, 400), "rgb": (131, 128, 112)},
{"position": (512, 512), "rgb": (134, 61, 59)},
{"position": (700, 400), "rgb": (204, 107, 43)},
{"position": (100, 700), "rgb": (201, 180, 165)},
{"position": (400, 700), "rgb": (140, 108, 87)},
{"position": (700, 700), "rgb": (247, 205, 145)},
{"position": (256, 256), "rgb": (171, 160, 153)},
{"position": (100, 100), "rgb": (115, 113, 94)},
{"position": (400, 50), "rgb": (159, 160, 144)},
{"position": (700, 100), "rgb": (164, 151, 123)},
{"position": (150, 400), "rgb": (120, 121, 107)},
{"position": (512, 512), "rgb": (165, 133, 127)},
{"position": (700, 400), "rgb": (217, 130, 66)},
{"position": (100, 700), "rgb": (191, 168, 152)},
{"position": (400, 700), "rgb": (130, 96, 77)},
{"position": (700, 700), "rgb": (247, 203, 140)},
{"position": (256, 256), "rgb": (167, 156, 150)},
]

# Maximum allowed difference per color channel
Expand Down
20 changes: 9 additions & 11 deletions vllm_omni/diffusion/models/bagel/pipeline_bagel.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,15 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
cfg_text_context["ropes"] = cfg_text_metadata["ropes"]
else:
cfg_text_context["ropes"] = [cfg_text_seq_len]

if cfg_img_kv is None and cfg_text_kv is not None:
else:
# No cfg_text companion received. For text2img this is the
# expected path: original BAGEL uses an empty KV cache (0
# tokens) as the text-unconditional branch. Keep the default
# empty NaiveCache in cfg_text_context and preserve the
# original cfg_text_scale so CFG still applies.
pass

if cfg_img_kv is None:
cfg_img_kv = injected_kv

if cfg_img_kv is not None:
Expand All @@ -410,15 +417,6 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
else:
cfg_img_context["ropes"] = [cfg_img_seq_len]

if not cfg_parallel_contract:
logger.warning("CFG is disabled: only single KV cache available")
gen_params = BagelGenParams(
num_timesteps=gen_params.num_timesteps,
timestep_shift=gen_params.timestep_shift,
cfg_text_scale=1.0,
cfg_img_scale=1.0,
)

else:
image_input = (
None
Expand Down
6 changes: 5 additions & 1 deletion vllm_omni/model_executor/stage_input_processors/bagel.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def expand_cfg_prompts(
neg_prompt = _get_negative_prompt(prompt, sampling_params)

if "image" in modalities:
if not neg_prompt:
return []
neg_prompt_dict = {
"prompt": neg_prompt,
"modalities": prompt.get("modalities", []),
Expand Down Expand Up @@ -166,6 +168,8 @@ def expand_cfg_prompts_think(
companion_params = {"max_tokens": 1}

if "image" in modalities:
if not neg_prompt:
return []
neg_prompt_dict = {
"prompt": neg_prompt,
"modalities": prompt.get("modalities", []),
Expand Down Expand Up @@ -300,4 +304,4 @@ def _get_negative_prompt(
if neg:
return neg

return "<|im_start|><|im_end|>"
return ""
Loading