Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 67 additions & 31 deletions examples/offline_inference/bagel/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,24 @@ def parse_args():
default=False,
help="Enable thinking mode: AR stage decodes <think>...</think> planning tokens before image generation.",
)
parser.add_argument(
"--max-think-tokens",
type=int,
default=1000,
help="Maximum number of tokens for thinking text generation (default: 1000).",
)
parser.add_argument(
"--do-sample",
action="store_true",
default=False,
help="Enable sampling for text generation (default: greedy).",
)
parser.add_argument(
"--text-temperature",
type=float,
default=0.3,
help="Temperature for text generation sampling (default: 0.3).",
)

args = parser.parse_args()
return args
Expand All @@ -108,7 +126,6 @@ def main():
model_name = args.model
prompts: list[OmniPromptType] = []
try:
# Preferred: load from txt file (one prompt per line)
if getattr(args, "txt_prompts", None) and args.prompt_type == "text":
with open(args.txt_prompts, encoding="utf-8") as f:
lines = [ln.strip() for ln in f.readlines()]
Expand All @@ -121,22 +138,22 @@ def main():
raise

if not prompts:
# Default prompt for text2img test if none provided
prompts = ["A cute cat"]
print(f"[Info] No prompts provided, using default: {prompts}")
omni_outputs = []

from PIL import Image

from vllm_omni.entrypoints.omni import Omni

omni_kwargs = {}
stage_configs_path = args.stage_configs_path
is_single_stage = stage_configs_path and "single_stage" in stage_configs_path
if args.think and stage_configs_path is None:
stage_configs_path = "vllm_omni/model_executor/stage_configs/bagel_think.yaml"
print(f"[Info] Think mode enabled, using stage config: {stage_configs_path}")
if stage_configs_path:
omni_kwargs["stage_configs_path"] = stage_configs_path
is_single_stage = "single_stage" in stage_configs_path

omni_kwargs.update(
{
Expand Down Expand Up @@ -198,40 +215,61 @@ def main():
formatted_prompts.append(prompt_dict)

params_list = omni.default_sampling_params_list

# For single-stage DiT, think/text params go into the diffusion sampling params extra_args.
# For 2-stage, diffusion params are at index 1.
diffusion_params_idx = 0 if is_single_stage else (1 if len(params_list) > 1 else 0)
diffusion_params = params_list[diffusion_params_idx]

if args.modality in ("text2img", "img2img"):
if len(params_list) > 1:
diffusion_params = params_list[1]
diffusion_params.num_inference_steps = args.steps # type: ignore
diffusion_params.cfg_parallel_size = args.cfg_parallel_size # type: ignore
if args.seed is not None:
diffusion_params.seed = args.seed # type: ignore
extra = {
"cfg_text_scale": args.cfg_text_scale,
"cfg_img_scale": args.cfg_img_scale,
}
if args.cfg_interval is not None:
extra["cfg_interval"] = tuple(args.cfg_interval)
if args.cfg_renorm_type is not None:
extra["cfg_renorm_type"] = args.cfg_renorm_type
if args.cfg_renorm_min is not None:
extra["cfg_renorm_min"] = args.cfg_renorm_min
if args.negative_prompt is not None:
extra["negative_prompt"] = args.negative_prompt
diffusion_params.extra_args = extra # type: ignore
diffusion_params.num_inference_steps = args.steps # type: ignore
diffusion_params.cfg_parallel_size = args.cfg_parallel_size # type: ignore
if args.seed is not None:
diffusion_params.seed = args.seed # type: ignore

extra = getattr(diffusion_params, "extra_args", {}) or {}
extra["cfg_text_scale"] = args.cfg_text_scale
extra["cfg_img_scale"] = args.cfg_img_scale
if args.cfg_interval is not None:
extra["cfg_interval"] = tuple(args.cfg_interval)
if args.cfg_renorm_type is not None:
extra["cfg_renorm_type"] = args.cfg_renorm_type
if args.cfg_renorm_min is not None:
extra["cfg_renorm_min"] = args.cfg_renorm_min
if args.negative_prompt is not None:
extra["negative_prompt"] = args.negative_prompt

needs_text_gen = is_single_stage and (args.think or args.modality in ("text2text", "img2text"))
if needs_text_gen:
if args.think:
extra["think"] = True
extra["max_think_tokens"] = args.max_think_tokens
extra["do_sample"] = args.do_sample
extra["text_temperature"] = args.text_temperature
diffusion_params.extra_args = extra # type: ignore

omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list))

img_idx = 0
for req_output in omni_outputs:
if args.think:
ro = getattr(req_output, "request_output", None)
if ro and getattr(ro, "outputs", None):
txt = "".join(getattr(o, "text", "") or "" for o in ro.outputs)
if txt:
print(txt)
# 2-stage think mode: text output from thinker stage
ro = getattr(req_output, "request_output", None)
if ro and getattr(ro, "outputs", None):
txt = "".join(getattr(o, "text", "") or "" for o in ro.outputs)
if txt:
if args.think:
print(f"[Think]\n{txt}")
else:
print(f"[Output] Text:\n{txt}")

images = getattr(req_output, "images", None)
# Single-stage DiT: text from custom_output
custom = getattr(req_output, "_custom_output", {}) or {}
if custom.get("think_text"):
print(f"[Think]\n{custom['think_text']}")
if custom.get("text_output"):
print(f"[Output] Text:\n{custom['text_output']}")

images = getattr(req_output, "images", None)
if not images:
continue

Expand All @@ -241,8 +279,6 @@ def main():
print(f"[Output] Saved image to {save_path}")
img_idx += 1

print(omni_outputs)


if __name__ == "__main__":
main()
113 changes: 110 additions & 3 deletions vllm_omni/diffusion/models/bagel/bagel_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,7 @@ def __init__(
config, parallel_config=parallel_config, quant_config=quant_config, prefix=f"{prefix}.model"
)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
Comment thread
princepride marked this conversation as resolved.

# Initialize weights and apply final processing
self.post_init()
Expand All @@ -864,6 +865,12 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.embed_tokens = value

def get_output_embeddings(self):
    """Return the module stored in ``self.lm_head``."""
    output_head = self.lm_head
    return output_head

def set_output_embeddings(self, new_embeddings):
    """Store ``new_embeddings`` as ``self.lm_head``."""
    setattr(self, "lm_head", new_embeddings)

def set_decoder(self, decoder):
    """Store ``decoder`` as ``self.model``."""
    setattr(self, "model", decoder)

Expand Down Expand Up @@ -1207,7 +1214,7 @@ def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
curr += curr_kvlen

text_ids = tokenizer.encode(prompt)
text_ids = tokenizer.encode(prompt, add_special_tokens=False)
text_ids = [new_token_ids["bos_token_id"]] + text_ids + [new_token_ids["eos_token_id"]]
text_token_lens.append(len(text_ids))
packed_text_ids.extend(text_ids)
Expand Down Expand Up @@ -1619,10 +1626,110 @@ def _merge_naive_caches(caches: list) -> NaiveCache:
num_layers = len(caches[0].key_cache)
merged = NaiveCache(num_layers)
for layer_idx in range(num_layers):
merged.key_cache[layer_idx] = torch.cat([c.key_cache[layer_idx] for c in caches], dim=0)
merged.value_cache[layer_idx] = torch.cat([c.value_cache[layer_idx] for c in caches], dim=0)
key_parts = [c.key_cache[layer_idx] for c in caches if c.key_cache[layer_idx] is not None]
val_parts = [c.value_cache[layer_idx] for c in caches if c.value_cache[layer_idx] is not None]
merged.key_cache[layer_idx] = torch.cat(key_parts, dim=0) if key_parts else None
merged.value_cache[layer_idx] = torch.cat(val_parts, dim=0) if val_parts else None
return merged

def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids):
    """Build the packed inputs that seed autoregressive text decoding.

    Ported from the original BAGEL ``Bagel.prepare_start_tokens``.

    For every ``(kv_len, rope_position)`` pair taken from ``curr_kvlens`` and
    ``curr_rope``, one BOS token (``new_token_ids["bos_token_id"]``) is emitted
    as the start token, the rope position becomes its query position id, and
    the indexes ``offset .. offset + kv_len - 1`` are recorded as that
    sequence's key/value slots, where ``offset`` is the running sum of the
    preceding kv lengths.

    Returns:
        A dict of tensors with keys ``packed_start_tokens``,
        ``packed_query_position_ids``, ``key_values_lens`` and
        ``packed_key_value_indexes``.
    """
    start_tokens = []
    query_positions = []
    kv_indexes = []

    offset = 0
    for kv_len, rope_pos in zip(curr_kvlens, curr_rope):
        next_offset = offset + kv_len
        kv_indexes += range(offset, next_offset)
        start_tokens.append(new_token_ids["bos_token_id"])
        query_positions.append(rope_pos)
        offset = next_offset

    return {
        "packed_start_tokens": torch.tensor(start_tokens, dtype=torch.long),
        "packed_query_position_ids": torch.tensor(query_positions, dtype=torch.long),
        "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
        "packed_key_value_indexes": torch.tensor(kv_indexes, dtype=torch.long),
    }

@torch.no_grad()
def generate_text(
self,
past_key_values: NaiveCache,
packed_key_value_indexes: torch.LongTensor,
key_values_lens: torch.IntTensor,
packed_start_tokens: torch.LongTensor,
packed_query_position_ids: torch.LongTensor,
max_length: int,
do_sample: bool = False,
temperature: float = 1.0,
end_token_id: int | None = None,
):
"""Autoregressive text generation (ported from original BAGEL).

Decodes tokens one at a time, appending to ``past_key_values``
until ``max_length`` is reached or ``end_token_id`` is generated.
"""
step = 0
generated_sequence = []
curr_tokens = packed_start_tokens
while step < max_length:
generated_sequence.append(curr_tokens)
packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
query_lens = torch.ones_like(curr_tokens)
packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange(
0,
len(key_values_lens),
device=key_values_lens.device,
dtype=key_values_lens.dtype,
)

uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
for i in range(len(uppacked)):
uppacked[i] += i
packed_key_value_indexes = torch.cat(uppacked, dim=0)

output = self.language_model(
packed_query_sequence=packed_text_embedding,
query_lens=query_lens,
packed_query_position_ids=packed_query_position_ids,
packed_query_indexes=packed_query_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=True,
is_causal=True,
mode="und",
)
past_key_values = output.past_key_values
packed_query_sequence = output.packed_query_sequence
pred_logits = self.language_model.lm_head(packed_query_sequence)

if do_sample:
probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
curr_tokens = torch.argmax(pred_logits, dim=-1)

uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
for i in range(len(uppacked)):
uppacked[i] = torch.cat(
[uppacked[i], torch.tensor([uppacked[i][-1] + 1], device=uppacked[i].device)], dim=0
)
packed_key_value_indexes = torch.cat(uppacked, dim=0)
key_values_lens = key_values_lens + 1
packed_query_position_ids = packed_query_position_ids + 1
step += 1

if end_token_id is not None and curr_tokens[0] == end_token_id:
break

output_device = generated_sequence[0].device
return torch.stack([i.to(output_device) for i in generated_sequence], dim=0)

def generate_image(
self,
packed_text_ids: torch.LongTensor,
Expand Down
Loading
Loading