diff --git a/examples/offline_inference/bagel/end2end.py b/examples/offline_inference/bagel/end2end.py index 472d748d1e6..ed5fa57e8d6 100644 --- a/examples/offline_inference/bagel/end2end.py +++ b/examples/offline_inference/bagel/end2end.py @@ -97,6 +97,24 @@ def parse_args(): default=False, help="Enable thinking mode: AR stage decodes ... planning tokens before image generation.", ) + parser.add_argument( + "--max-think-tokens", + type=int, + default=1000, + help="Maximum number of tokens for thinking text generation (default: 1000).", + ) + parser.add_argument( + "--do-sample", + action="store_true", + default=False, + help="Enable sampling for text generation (default: greedy).", + ) + parser.add_argument( + "--text-temperature", + type=float, + default=0.3, + help="Temperature for text generation sampling (default: 0.3).", + ) args = parser.parse_args() return args @@ -108,7 +126,6 @@ def main(): model_name = args.model prompts: list[OmniPromptType] = [] try: - # Preferred: load from txt file (one prompt per line) if getattr(args, "txt_prompts", None) and args.prompt_type == "text": with open(args.txt_prompts, encoding="utf-8") as f: lines = [ln.strip() for ln in f.readlines()] @@ -121,10 +138,8 @@ def main(): raise if not prompts: - # Default prompt for text2img test if none provided prompts = ["A cute cat"] print(f"[Info] No prompts provided, using default: {prompts}") - omni_outputs = [] from PIL import Image @@ -132,11 +147,13 @@ def main(): omni_kwargs = {} stage_configs_path = args.stage_configs_path + is_single_stage = stage_configs_path and "single_stage" in stage_configs_path if args.think and stage_configs_path is None: stage_configs_path = "vllm_omni/model_executor/stage_configs/bagel_think.yaml" print(f"[Info] Think mode enabled, using stage config: {stage_configs_path}") if stage_configs_path: omni_kwargs["stage_configs_path"] = stage_configs_path + is_single_stage = "single_stage" in stage_configs_path omni_kwargs.update( { @@ -198,40 +215,61 @@ def main(): formatted_prompts.append(prompt_dict) params_list = omni.default_sampling_params_list + + # For single-stage DiT, think/text params go into the diffusion sampling params extra_args. + # For 2-stage, diffusion params are at index 1. + diffusion_params_idx = 0 if is_single_stage else (1 if len(params_list) > 1 else 0) + diffusion_params = params_list[diffusion_params_idx] + if args.modality in ("text2img", "img2img"): - if len(params_list) > 1: - diffusion_params = params_list[1] - diffusion_params.num_inference_steps = args.steps # type: ignore - diffusion_params.cfg_parallel_size = args.cfg_parallel_size # type: ignore - if args.seed is not None: - diffusion_params.seed = args.seed # type: ignore - extra = { - "cfg_text_scale": args.cfg_text_scale, - "cfg_img_scale": args.cfg_img_scale, - } - if args.cfg_interval is not None: - extra["cfg_interval"] = tuple(args.cfg_interval) - if args.cfg_renorm_type is not None: - extra["cfg_renorm_type"] = args.cfg_renorm_type - if args.cfg_renorm_min is not None: - extra["cfg_renorm_min"] = args.cfg_renorm_min - if args.negative_prompt is not None: - extra["negative_prompt"] = args.negative_prompt - diffusion_params.extra_args = extra # type: ignore + diffusion_params.num_inference_steps = args.steps # type: ignore + diffusion_params.cfg_parallel_size = args.cfg_parallel_size # type: ignore + if args.seed is not None: + diffusion_params.seed = args.seed # type: ignore + + extra = getattr(diffusion_params, "extra_args", {}) or {} + extra["cfg_text_scale"] = args.cfg_text_scale + extra["cfg_img_scale"] = args.cfg_img_scale + if args.cfg_interval is not None: + extra["cfg_interval"] = tuple(args.cfg_interval) + if args.cfg_renorm_type is not None: + extra["cfg_renorm_type"] = args.cfg_renorm_type + if args.cfg_renorm_min is not None: + extra["cfg_renorm_min"] = args.cfg_renorm_min + if args.negative_prompt is not None: + extra["negative_prompt"] = args.negative_prompt + + needs_text_gen = is_single_stage and (args.think or args.modality in ("text2text", "img2text")) + if needs_text_gen: + if args.think: + extra["think"] = True + extra["max_think_tokens"] = args.max_think_tokens + extra["do_sample"] = args.do_sample + extra["text_temperature"] = args.text_temperature + diffusion_params.extra_args = extra # type: ignore omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list)) img_idx = 0 for req_output in omni_outputs: - if args.think: - ro = getattr(req_output, "request_output", None) - if ro and getattr(ro, "outputs", None): - txt = "".join(getattr(o, "text", "") or "" for o in ro.outputs) - if txt: - print(txt) + # 2-stage think mode: text output from thinker stage + ro = getattr(req_output, "request_output", None) + if ro and getattr(ro, "outputs", None): + txt = "".join(getattr(o, "text", "") or "" for o in ro.outputs) + if txt: + if args.think: + print(f"[Think]\n{txt}") + else: + print(f"[Output] Text:\n{txt}") - images = getattr(req_output, "images", None) + # Single-stage DiT: text from custom_output + custom = getattr(req_output, "_custom_output", {}) or {} + if custom.get("think_text"): + print(f"[Think]\n{custom['think_text']}") + if custom.get("text_output"): + print(f"[Output] Text:\n{custom['text_output']}") + images = getattr(req_output, "images", None) if not images: continue @@ -241,8 +279,6 @@ def main(): print(f"[Output] Saved image to {save_path}") img_idx += 1 - print(omni_outputs) - if __name__ == "__main__": main() diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index f8480775687..d1254f84566 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -854,6 +854,7 @@ def __init__( config, parallel_config=parallel_config, quant_config=quant_config, prefix=f"{prefix}.model" ) self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() @@ -864,6 +865,12 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.embed_tokens = value + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + def set_decoder(self, decoder): self.model = decoder @@ -1207,7 +1214,7 @@ def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) curr += curr_kvlen - text_ids = tokenizer.encode(prompt) + text_ids = tokenizer.encode(prompt, add_special_tokens=False) text_ids = [new_token_ids["bos_token_id"]] + text_ids + [new_token_ids["eos_token_id"]] text_token_lens.append(len(text_ids)) packed_text_ids.extend(text_ids) @@ -1619,10 +1626,110 @@ def _merge_naive_caches(caches: list) -> NaiveCache: num_layers = len(caches[0].key_cache) merged = NaiveCache(num_layers) for layer_idx in range(num_layers): - merged.key_cache[layer_idx] = torch.cat([c.key_cache[layer_idx] for c in caches], dim=0) - merged.value_cache[layer_idx] = torch.cat([c.value_cache[layer_idx] for c in caches], dim=0) + key_parts = [c.key_cache[layer_idx] for c in caches if c.key_cache[layer_idx] is not None] + val_parts = [c.value_cache[layer_idx] for c in caches if c.value_cache[layer_idx] is not None] + merged.key_cache[layer_idx] = torch.cat(key_parts, dim=0) if key_parts else None + merged.value_cache[layer_idx] = torch.cat(val_parts, dim=0) if val_parts else None return merged + def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids): + """Prepare start tokens for autoregressive text generation. + + Ported from the original BAGEL ``Bagel.prepare_start_tokens``. + """ + packed_start_tokens, packed_key_value_indexes = list(), list() + packed_query_position_ids = list() + + curr = 0 + for curr_kvlen, curr_position_id in zip(curr_kvlens, curr_rope): + packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) + packed_start_tokens.append(new_token_ids["bos_token_id"]) + packed_query_position_ids.append(curr_position_id) + curr += curr_kvlen + + generation_input = { + "packed_start_tokens": torch.tensor(packed_start_tokens, dtype=torch.long), + "packed_query_position_ids": torch.tensor(packed_query_position_ids, dtype=torch.long), + "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), + "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), + } + return generation_input + + @torch.no_grad() + def generate_text( + self, + past_key_values: NaiveCache, + packed_key_value_indexes: torch.LongTensor, + key_values_lens: torch.IntTensor, + packed_start_tokens: torch.LongTensor, + packed_query_position_ids: torch.LongTensor, + max_length: int, + do_sample: bool = False, + temperature: float = 1.0, + end_token_id: int | None = None, + ): + """Autoregressive text generation (ported from original BAGEL). + + Decodes tokens one at a time, appending to ``past_key_values`` + until ``max_length`` is reached or ``end_token_id`` is generated. + """ + step = 0 + generated_sequence = [] + curr_tokens = packed_start_tokens + while step < max_length: + generated_sequence.append(curr_tokens) + packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens) + query_lens = torch.ones_like(curr_tokens) + packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange( + 0, + len(key_values_lens), + device=key_values_lens.device, + dtype=key_values_lens.dtype, + ) + + uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0)) + for i in range(len(uppacked)): + uppacked[i] += i + packed_key_value_indexes = torch.cat(uppacked, dim=0) + + output = self.language_model( + packed_query_sequence=packed_text_embedding, + query_lens=query_lens, + packed_query_position_ids=packed_query_position_ids, + packed_query_indexes=packed_query_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=True, + is_causal=True, + mode="und", + ) + past_key_values = output.past_key_values + packed_query_sequence = output.packed_query_sequence + pred_logits = self.language_model.lm_head(packed_query_sequence) + + if do_sample: + probs = nn.functional.softmax(pred_logits / temperature, dim=-1) + curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + curr_tokens = torch.argmax(pred_logits, dim=-1) + + uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0)) + for i in range(len(uppacked)): + uppacked[i] = torch.cat( + [uppacked[i], torch.tensor([uppacked[i][-1] + 1], device=uppacked[i].device)], dim=0 + ) + packed_key_value_indexes = torch.cat(uppacked, dim=0) + key_values_lens = key_values_lens + 1 + packed_query_position_ids = packed_query_position_ids + 1 + step += 1 + + if end_token_id is not None and curr_tokens[0] == end_token_id: + break + + output_device = generated_sequence[0].device + return torch.stack([i.to(output_device) for i in generated_sequence], dim=0) + def generate_image( self, packed_text_ids: torch.LongTensor, diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index 13d0cc2093b..72e53e7f48f 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -495,11 +495,15 @@ def vae_transforms(img): cfg_text_context = deepcopy(gen_context) + # Strip <|im_start|>/<|im_end|> wrappers that end2end.py may have + # already added, so prepare_prompts doesn't double-add bos/eos. + clean_prompt = prompt.removeprefix("<|im_start|>").removesuffix("<|im_end|>") + # Update gen_context with text prompt generation_input, newlens, new_rope = self.bagel.prepare_prompts( curr_kvlens=gen_context["kv_lens"], curr_rope=gen_context["ropes"], - prompts=[prompt], + prompts=[clean_prompt], tokenizer=self.tokenizer, new_token_ids=self.new_token_ids, ) @@ -527,34 +531,37 @@ def vae_transforms(img): gen_context["kv_lens"] = newlens gen_context["ropes"] = new_rope - # cfg_text_context: update with negative prompt (no text condition) + # cfg_text_context: update with negative prompt (no text condition). + # When empty, keep cfg_text_context as-is (kv_lens=0) to match + # original BAGEL; _merge_naive_caches handles None KV entries. neg_prompt = extra_args.get("negative_prompt", "") - neg_input, neg_newlens, neg_rope = self.bagel.prepare_prompts( - curr_kvlens=cfg_text_context["kv_lens"], - curr_rope=cfg_text_context["ropes"], - prompts=[neg_prompt], - tokenizer=self.tokenizer, - new_token_ids=self.new_token_ids, - ) - for k, v in neg_input.items(): - if torch.is_tensor(v): - neg_input[k] = v.to(self.device) - with torch.autocast( - device_type=self.device.type, - enabled=self.device.type != "cpu", - dtype=self.od_config.dtype, - ): - cfg_text_context["past_key_values"] = self.bagel.forward_cache_update_text( - cfg_text_context["past_key_values"], **neg_input + if neg_prompt: + neg_input, neg_newlens, neg_rope = self.bagel.prepare_prompts( + curr_kvlens=cfg_text_context["kv_lens"], + curr_rope=cfg_text_context["ropes"], + prompts=[neg_prompt], + tokenizer=self.tokenizer, + new_token_ids=self.new_token_ids, ) - cfg_text_context["kv_lens"] = neg_newlens - cfg_text_context["ropes"] = neg_rope + for k, v in neg_input.items(): + if torch.is_tensor(v): + neg_input[k] = v.to(self.device) + with torch.autocast( + device_type=self.device.type, + enabled=self.device.type != "cpu", + dtype=self.od_config.dtype, + ): + cfg_text_context["past_key_values"] = self.bagel.forward_cache_update_text( + cfg_text_context["past_key_values"], **neg_input + ) + cfg_text_context["kv_lens"] = neg_newlens + cfg_text_context["ropes"] = neg_rope # cfg_img_context: update with text prompt (no image condition) cfg_img_generation_input, cfg_img_newlens, cfg_img_new_rope = self.bagel.prepare_prompts( curr_kvlens=cfg_img_context["kv_lens"], curr_rope=cfg_img_context["ropes"], - prompts=[prompt], + prompts=[clean_prompt], tokenizer=self.tokenizer, new_token_ids=self.new_token_ids, ) @@ -572,6 +579,96 @@ def vae_transforms(img): cfg_img_context["kv_lens"] = cfg_img_newlens cfg_img_context["ropes"] = cfg_img_new_rope + # ---- Detect output modality and think mode ---- + modalities = first_prompt.get("modalities", []) if isinstance(first_prompt, dict) else [] + is_text_output = "text" in modalities + think_enabled = extra_args.get("think", False) + think_text = None + + if think_enabled and injected_kv is None: + max_think_tokens = int(extra_args.get("max_think_tokens", 1000)) + do_sample = bool(extra_args.get("do_sample", False)) + text_temperature = float(extra_args.get("text_temperature", 0.3)) + + with torch.autocast( + device_type=self.device.type, + enabled=self.device.type != "cpu", + dtype=self.od_config.dtype, + ): + start_input = self.bagel.prepare_start_tokens( + gen_context["kv_lens"], gen_context["ropes"], self.new_token_ids + ) + for k, v in start_input.items(): + if torch.is_tensor(v): + start_input[k] = v.to(self.device) + + gen_ctx_copy = deepcopy(gen_context) + token_ids = self.bagel.generate_text( + past_key_values=gen_ctx_copy["past_key_values"], + max_length=max_think_tokens, + do_sample=do_sample, + temperature=text_temperature, + end_token_id=self.new_token_ids["eos_token_id"], + **start_input, + ) + # token_ids shape: (seq_len, batch=1) + decoded = self.tokenizer.decode(token_ids[:, 0].tolist()) + # Strip chat markers to get clean text + think_text = decoded.split("<|im_end|>")[0] + if "<|im_start|>" in think_text: + think_text = think_text.split("<|im_start|>")[-1] + logger.info("Think mode generated %d tokens", token_ids.shape[0]) + + if not is_text_output: + # Use the autoregressive KV cache from think generation + # directly, instead of decode→re-encode which adds extra + # bos/eos and may alter tokenization. + num_think_tokens = token_ids.shape[0] + gen_context["past_key_values"] = gen_ctx_copy["past_key_values"] + gen_context["kv_lens"] = [kl + num_think_tokens for kl in gen_context["kv_lens"]] + gen_context["ropes"] = [r + num_think_tokens for r in gen_context["ropes"]] + + # ---- Text-only output (text2text / img2text) ---- + if is_text_output and injected_kv is None: + if think_text is not None: + # Think mode already generated the text (including reasoning) + text_output = think_text + else: + max_text_tokens = int(extra_args.get("max_think_tokens", 500)) + do_sample = bool(extra_args.get("do_sample", False)) + text_temperature = float(extra_args.get("text_temperature", 0.3)) + + with torch.autocast( + device_type=self.device.type, + enabled=self.device.type != "cpu", + dtype=self.od_config.dtype, + ): + start_input = self.bagel.prepare_start_tokens( + gen_context["kv_lens"], gen_context["ropes"], self.new_token_ids + ) + for k, v in start_input.items(): + if torch.is_tensor(v): + start_input[k] = v.to(self.device) + token_ids = self.bagel.generate_text( + past_key_values=gen_context["past_key_values"], + max_length=max_text_tokens, + do_sample=do_sample, + temperature=text_temperature, + end_token_id=self.new_token_ids["eos_token_id"], + **start_input, + ) + decoded = self.tokenizer.decode(token_ids[:, 0].tolist()) + text_output = decoded.split("<|im_end|>")[0] + if "<|im_start|>" in text_output: + text_output = text_output.split("<|im_start|>")[-1] + + return DiffusionOutput( + output=text_output, + custom_output={"text_output": text_output}, + stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None, + ) + + # ---- Image generation (text2img / img2img) ---- if req.sampling_params.seed is not None: torch.manual_seed(req.sampling_params.seed) if self.device.type == "cuda": @@ -676,12 +773,17 @@ def vae_transforms(img): if trajectory_log_probs: trajectory_log_probs_stacked = torch.stack(trajectory_log_probs) + custom = {} + if think_text is not None: + custom["think_text"] = think_text + return DiffusionOutput( output=img, trajectory_latents=trajectory_latents_stacked, trajectory_timesteps=trajectory_timesteps_stacked, trajectory_log_probs=trajectory_log_probs_stacked, trajectory_decoded=trajectory_decoded, + custom_output=custom, stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None, )