From a109b772ef6098e180e0c892496e9789fca6df21 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Thu, 15 Jan 2026 17:01:52 +0800 Subject: [PATCH 1/3] fix glm image Signed-off-by: JaredforReal --- .../diffusion/models/glm_image/__init__.py | 4 +- .../models/glm_image/pipeline_glm_image.py | 322 ++++++++++-------- vllm_omni/diffusion/registry.py | 1 + 3 files changed, 174 insertions(+), 153 deletions(-) diff --git a/vllm_omni/diffusion/models/glm_image/__init__.py b/vllm_omni/diffusion/models/glm_image/__init__.py index ac7a98fa743..fc8256d8de6 100644 --- a/vllm_omni/diffusion/models/glm_image/__init__.py +++ b/vllm_omni/diffusion/models/glm_image/__init__.py @@ -9,7 +9,7 @@ from vllm_omni.diffusion.models.glm_image.pipeline_glm_image import ( GlmImagePipeline, get_glm_image_post_process_func, - # get_glm_image_pre_process_func, + get_glm_image_pre_process_func, ) __all__ = [ @@ -17,5 +17,5 @@ "GlmImagePipeline", "GlmImageTransformer2DModel", "get_glm_image_post_process_func", - # "get_glm_image_pre_process_func", + "get_glm_image_pre_process_func", ] diff --git a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py index f582c3b9b69..53fea3d79e4 100644 --- a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py +++ b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py @@ -17,7 +17,6 @@ import os import re from collections.abc import Iterable -from math import sqrt import numpy as np import PIL.Image @@ -56,6 +55,73 @@ logger = logging.getLogger(__name__) +def get_glm_image_pre_process_func(od_config: OmniDiffusionConfig): + """Get pre-processing function for GLM-Image pipeline. + + Pre-processes condition images before they are sent to the pipeline. + This is called by DiffusionEngine before batching requests. + """ + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + block_out_channels = vae_config.get("block_out_channels", [128, 256, 512, 512]) + vae_scale_factor = 2 ** (len(block_out_channels) - 1) + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + # GLM-Image uses patch_size=2 for transformer + patch_size = 2 + + def pre_process_func(requests: list[OmniDiffusionRequest]): + """Pre-process condition images for Image Edit mode.""" + for req in requests: + images = req.pil_image + if images is None: + # Text-to-image mode, no preprocessing needed + continue + + if not isinstance(images, list): + images = [images] + + preprocessed = [] + height, width = None, None + + for img in images: + if isinstance(img, PIL.Image.Image): + img_h, img_w = img.size[::-1] # PIL is (width, height) + else: + img_h, img_w = img.shape[:2] + + # Align to multiple of vae_scale_factor * patch_size + multiple_of = vae_scale_factor * patch_size + img_h = (img_h // multiple_of) * multiple_of + img_w = (img_w // multiple_of) * multiple_of + + processed = image_processor.preprocess(img, height=img_h, width=img_w) + preprocessed.append(processed) + + # Use first image dimensions as default + if height is None: + height, width = img_h, img_w + + # Store in request + req.preprocessed_image = preprocessed + req.prompt_image = images # Keep original PIL images + if req.height is None: + req.height = height + if req.width is None: + req.width = width + + return requests + + return pre_process_func + + def get_glm_image_post_process_func(od_config: OmniDiffusionConfig): """Get post-processing function for GLM-Image pipeline.""" model_name = od_config.model @@ -73,7 +139,7 @@ def get_glm_image_post_process_func(od_config: OmniDiffusionConfig): image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) def post_process_func(images: PIL.Image.Image): - return images + return image_processor.postprocess(images, output_type="pil") return post_process_func @@ -291,45 +357,41 @@ def check_inputs( # ==================== AR Stage Methods ==================== @staticmethod - def _build_image_grid_thw( - token_h: int, - token_w: int, - prev_token_h: int, - prev_token_w: int, - existing_grid: torch.Tensor | None = None, - device: torch.device | None = None, - ) -> torch.Tensor: - """Build image grid tensor for AR model.""" - if existing_grid is None or existing_grid.numel() == 0: - return torch.tensor( - [ - [1, token_h, token_w], - [1, prev_token_h, prev_token_w], - ], - device=device, - ) - else: - return torch.cat( - [existing_grid.to(device), torch.tensor([[1, token_h, token_w]], device=device)], - dim=0, - ) + def _compute_generation_params( + image_grid_thw: torch.Tensor, + is_text_to_image: bool, + ) -> tuple[int, int, int, int]: + """ + Compute AR generation parameters from image grid. - @staticmethod - def _calculate_ar_generation_params( - token_h: int, token_w: int, prev_token_h: int, prev_token_w: int, is_text_to_image: bool - ) -> tuple[int, int]: - """Calculate AR generation parameters.""" - large_image_tokens = token_h * token_w - small_image_tokens = prev_token_h * prev_token_w + Args: + image_grid_thw: Image grid tensor of shape [N, 3] where each row is [t, h, w] + is_text_to_image: Whether this is text-to-image (vs image-to-image) - if is_text_to_image: - max_new_tokens = small_image_tokens + large_image_tokens + 1 - large_image_start_offset = small_image_tokens - else: - max_new_tokens = large_image_tokens + 1 + Returns: + Tuple of (max_new_tokens, large_image_start_offset, target_grid_h, target_grid_w) + """ + grid_sizes = [] + grid_hw = [] + + for i in range(image_grid_thw.shape[0]): + t, h, w = image_grid_thw[i].tolist() + grid_sizes.append(int(h * w)) + grid_hw.append((int(h), int(w))) + + if not is_text_to_image: + # Image-to-image: only generate target image tokens + max_new_tokens = grid_sizes[-1] + 1 large_image_start_offset = 0 + target_grid_h, target_grid_w = grid_hw[-1] + else: + # Text-to-image: generate both small preview and large target + total_tokens = sum(grid_sizes) + max_new_tokens = total_tokens + 1 + large_image_start_offset = sum(grid_sizes[1:]) + target_grid_h, target_grid_w = grid_hw[0] - return max_new_tokens, large_image_start_offset + return max_new_tokens, large_image_start_offset, target_grid_h, target_grid_w @staticmethod def _extract_large_image_tokens( @@ -351,28 +413,6 @@ def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) -> token_ids = token_ids.view(1, -1) return token_ids - @staticmethod - def _build_prompt_with_shape( - prompt: str, - height: int, - width: int, - is_text_to_image: bool, - factor: int = 32, - ) -> tuple[str, int, int, int, int]: - """Build prompt with shape information for AR model.""" - token_h = height // factor - token_w = width // factor - ratio = token_h / token_w - prev_token_h = int(sqrt(ratio) * (factor // 2)) - prev_token_w = int(sqrt(1 / ratio) * (factor // 2)) - - if is_text_to_image: - expanded_prompt = f"{prompt}{token_h} {token_w}{prev_token_h} {prev_token_w}" - else: - expanded_prompt = f"{prompt}{token_h} {token_w}" - - return expanded_prompt, token_h, token_w, prev_token_h, prev_token_w - @torch.inference_mode() def generate_prior_tokens( self, @@ -381,7 +421,7 @@ def generate_prior_tokens( width: int, image: list[PIL.Image.Image] | None = None, factor: int = 32, - ) -> tuple[torch.Tensor, torch.Tensor | None, int, int]: + ) -> tuple[torch.Tensor, list[torch.Tensor] | None]: """ Generate prior tokens using the AR model. @@ -393,63 +433,67 @@ def generate_prior_tokens( factor: Token factor (default 32) Returns: - Tuple of (prior_token_ids, prior_token_image_ids, pixel_height, pixel_width) + Tuple of (prior_token_ids, prior_token_image_ids) + prior_token_image_ids is a list of tensors, one per condition image """ device = self.vision_language_encoder.device height = (height // factor) * factor width = (width // factor) * factor is_text_to_image = image is None or len(image) == 0 - expanded_prompt, token_h, token_w, prev_h, prev_w = self._build_prompt_with_shape( - prompt, height, width, is_text_to_image - ) - # Build message content content = [] if image is not None: for img in image: content.append({"type": "image", "image": img}) - content.append({"type": "text", "text": expanded_prompt}) + content.append({"type": "text", "text": prompt}) messages = [{"role": "user", "content": content}] - # Apply chat template + # Apply chat template - processor will handle target dimensions and build grid inputs = self.processor.apply_chat_template( messages, - add_generation_prompt=True, tokenize=True, + target_h=height, + target_w=width, return_dict=True, return_tensors="pt", - ) + ).to(device) - # Build image grid - existing_grid = inputs.get("image_grid_thw") - inputs["image_grid_thw"] = self._build_image_grid_thw( - token_h, - token_w, - prev_h, - prev_w, - existing_grid=existing_grid if not is_text_to_image else None, - device=device, - ) + image_grid_thw = inputs.get("image_grid_thw") - max_new_tokens, large_image_offset = self._calculate_ar_generation_params( - token_h, token_w, prev_h, prev_w, is_text_to_image + # Compute generation parameters from the full grid + max_new_tokens, large_image_offset, token_h, token_w = self._compute_generation_params( + image_grid_thw=image_grid_thw, is_text_to_image=is_text_to_image ) - large_image_tokens = token_h * token_w - - inputs = inputs.to(device) - input_length = inputs["input_ids"].shape[-1] # Process condition images if provided + # Use image_grid_thw[:-1] to exclude the target image grid (last entry) prior_token_image_ids = None - if image is not None and existing_grid is not None: + if image is not None and image_grid_thw is not None and len(image_grid_thw) > 1: + # Get features only for condition images (exclude target image grid) + condition_grid = image_grid_thw[:-1] prior_token_image_embed = self.vision_language_encoder.get_image_features( - inputs["pixel_values"], existing_grid + inputs["pixel_values"], condition_grid ) prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0) - prior_token_image_ids = self.vision_language_encoder.get_image_tokens( - prior_token_image_embed, existing_grid + flat_prior_token_image_ids = self.vision_language_encoder.get_image_tokens( + prior_token_image_embed, condition_grid ) + # Split by image grid sizes and convert to list + split_sizes = (condition_grid.prod(dim=-1)).tolist() + prior_token_image_ids_list = torch.split(flat_prior_token_image_ids, split_sizes, dim=0) + # Convert to list with upsampling + prior_token_image_ids = [] + for i, token_ids in enumerate(prior_token_image_ids_list): + grid_t, grid_h, grid_w = condition_grid[i].tolist() + token_ids = token_ids.view(1, -1) + # Upsample 2x (from d32 to d64) + token_ids_2d = token_ids.view(1, 1, grid_h, grid_w) + token_ids_upsampled = torch.nn.functional.interpolate( + token_ids_2d.float(), scale_factor=2, mode="nearest" + ).to(dtype=torch.long) + token_ids_upsampled = token_ids_upsampled.view(1, -1) + prior_token_image_ids.append(token_ids_upsampled) # Generate with AR model outputs = self.vision_language_encoder.generate( @@ -459,8 +503,9 @@ def generate_prior_tokens( ) # Extract and upsample tokens + large_image_tokens = token_h * token_w prior_token_ids_d32 = self._extract_large_image_tokens( - outputs, input_length, large_image_offset, large_image_tokens + outputs, inputs["input_ids"].shape[-1], large_image_offset, large_image_tokens ) prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w) @@ -634,7 +679,7 @@ def diffuse( timestep=timestep, target_size=target_size, crop_coords=crop_coords, - kv_caches=kv_caches, + kv_cache=kv_caches, return_dict=False, )[0].float() else: @@ -647,7 +692,7 @@ def diffuse( timestep=timestep, target_size=target_size, crop_coords=crop_coords, - kv_caches=kv_caches, + kv_cache=kv_caches, return_dict=False, )[0].float() @@ -690,7 +735,7 @@ def diffuse( timestep=timestep, target_size=target_size, crop_coords=crop_coords, - kv_caches=kv_caches, + kv_cache=kv_caches, return_dict=False, )[0].float() @@ -763,54 +808,12 @@ def _prepare_condition_image_kv_cache( timestep=torch.zeros((1,), device=self.device), target_size=torch.tensor([condition_image.shape[-2:]], device=self.device, dtype=prompt_embeds.dtype), crop_coords=torch.zeros((1, 2), device=self.device, dtype=prompt_embeds.dtype), - kv_caches=kv_caches, + kv_cache=kv_caches, return_dict=False, ) return kv_caches - def _preprocess_condition_images( - self, - images: list[PIL.Image.Image] | PIL.Image.Image | None, - ) -> tuple[list[torch.Tensor] | None, int | None, int | None]: - """ - Preprocess condition images for Image Edit mode. - - Args: - images: Input images (PIL or list of PIL) - - Returns: - Tuple of (preprocessed_images, height, width) - """ - if images is None: - return None, None, None - - if not isinstance(images, list): - images = [images] - - preprocessed = [] - height, width = None, None - - for img in images: - if isinstance(img, PIL.Image.Image): - img_h, img_w = img.size[::-1] - else: - img_h, img_w = img.shape[:2] - - # Align to multiple of vae_scale_factor * patch_size - multiple_of = self.vae_scale_factor * self._patch_size - img_h = (img_h // multiple_of) * multiple_of - img_w = (img_w // multiple_of) * multiple_of - - processed = self.image_processor.preprocess(img, height=img_h, width=img_w) - preprocessed.append(processed) - - # Use first image dimensions as default - if height is None: - height, width = img_h, img_w - - return preprocessed, height, width - @torch.inference_mode() def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: """ @@ -830,12 +833,12 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: prompt_embeds = req.prompt_embeds if isinstance(req.prompt_embeds, torch.Tensor) else None # Get condition images for Image Edit mode - condition_images = req.pil_image - if condition_images is not None and not isinstance(condition_images, list): - condition_images = [condition_images] + # Use pre-processed images from pre_process_func + preprocessed_images = req.preprocessed_image + condition_images = req.prompt_image if hasattr(req, "prompt_image") else req.pil_image + img_height = req.height + img_width = req.width - # Preprocess condition images and get dimensions - preprocessed_images, img_height, img_width = self._preprocess_condition_images(condition_images) is_image_edit = preprocessed_images is not None # Use image dimensions as default if available @@ -855,14 +858,32 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: if req.seed is not None: generator = torch.Generator(device=self.device).manual_seed(req.seed) - # 1. Generate prior tokens with AR model - logger.info("Generating prior tokens with AR model...") - prior_token_id, prior_token_image_ids = self.generate_prior_tokens( - prompt=prompt, - image=condition_images, - height=height, - width=width, - ) + # 1. Get prior tokens - either from external source (multistage) or generate internally + # Check if prior_token_ids are provided externally (from AR stage in multistage mode) + external_prior_tokens = req.extra.get("prior_token_ids") if req.extra else None + external_prior_image_ids = req.extra.get("prior_token_image_ids") if req.extra else None + + if external_prior_tokens is not None: + # Multistage mode: use externally provided prior tokens from vLLM AR stage + logger.info("Using externally provided prior tokens from AR stage...") + prior_token_id = external_prior_tokens + if isinstance(prior_token_id, list): + prior_token_id = torch.tensor(prior_token_id, dtype=torch.long, device=self.device) + elif isinstance(prior_token_id, torch.Tensor): + prior_token_id = prior_token_id.to(device=self.device, dtype=torch.long) + # Ensure shape is [1, num_tokens] for batch processing + if prior_token_id.dim() == 1: + prior_token_id = prior_token_id.unsqueeze(0) + prior_token_image_ids = external_prior_image_ids + else: + # Single-stage mode: generate prior tokens with internal AR model + logger.info("Generating prior tokens with AR model...") + prior_token_id, prior_token_image_ids = self.generate_prior_tokens( + prompt=prompt, + image=condition_images, + height=height, + width=width, + ) # 2. Encode prompt for glyph embeddings logger.info("Encoding prompt...") @@ -951,8 +972,7 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: latents = latents * latents_std + latents_mean image = self.vae.decode(latents, return_dict=False, generator=generator)[0] - # 9. Post-process - image = self.image_processor.postprocess(image, output_type="pil")[0] + # 9. Leave post-process to vllm-omni pipeline return DiffusionOutput(output=image) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index e566ca66cfa..94cf889210f 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -133,6 +133,7 @@ def initialize_model( # arch: pre_process_func # `pre_process_func` function must be placed in {mod_folder}/{mod_relname}.py, # where mod_folder and mod_relname are defined and mapped using `_DIFFUSION_MODELS` via the `arch` key + "GlmImagePipeline": "get_glm_image_pre_process_func", "QwenImageEditPipeline": "get_qwen_image_edit_pre_process_func", "QwenImageEditPlusPipeline": "get_qwen_image_edit_plus_pre_process_func", "LongCatImageEditPipeline": "get_longcat_image_edit_pre_process_func", From 9b20c8eb0f96bd1b5abc9c77ec0bfd63cfbb8088 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Thu, 15 Jan 2026 17:34:12 +0800 Subject: [PATCH 2/3] accept some reviews Signed-off-by: JaredforReal --- .../diffusion/models/glm_image/pipeline_glm_image.py | 10 +++------- vllm_omni/model_executor/models/glm_image/glm_image.py | 0 2 files changed, 3 insertions(+), 7 deletions(-) create mode 100644 vllm_omni/model_executor/models/glm_image/glm_image.py diff --git a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py index 53fea3d79e4..9a03a934983 100644 --- a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py +++ b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py @@ -138,7 +138,7 @@ def get_glm_image_post_process_func(od_config: OmniDiffusionConfig): image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) - def post_process_func(images: PIL.Image.Image): + def post_process_func(images: torch.Tensor) -> list[PIL.Image.Image]: return image_processor.postprocess(images, output_type="pil") return post_process_func @@ -488,11 +488,7 @@ def generate_prior_tokens( grid_t, grid_h, grid_w = condition_grid[i].tolist() token_ids = token_ids.view(1, -1) # Upsample 2x (from d32 to d64) - token_ids_2d = token_ids.view(1, 1, grid_h, grid_w) - token_ids_upsampled = torch.nn.functional.interpolate( - token_ids_2d.float(), scale_factor=2, mode="nearest" - ).to(dtype=torch.long) - token_ids_upsampled = token_ids_upsampled.view(1, -1) + token_ids_upsampled = self._upsample_token_ids(token_ids, grid_h, grid_w) prior_token_image_ids.append(token_ids_upsampled) # Generate with AR model @@ -835,7 +831,7 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: # Get condition images for Image Edit mode # Use pre-processed images from pre_process_func preprocessed_images = req.preprocessed_image - condition_images = req.prompt_image if hasattr(req, "prompt_image") else req.pil_image + condition_images = getattr(req, "prompt_image", None) img_height = req.height img_width = req.width diff --git a/vllm_omni/model_executor/models/glm_image/glm_image.py b/vllm_omni/model_executor/models/glm_image/glm_image.py new file mode 100644 index 00000000000..e69de29bb2d From fd10877159a17dddde9003743e5e5789b403b389 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Thu, 15 Jan 2026 17:40:50 +0800 Subject: [PATCH 3/3] remove empty file Signed-off-by: JaredforReal --- vllm_omni/model_executor/models/glm_image/glm_image.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 vllm_omni/model_executor/models/glm_image/glm_image.py diff --git a/vllm_omni/model_executor/models/glm_image/glm_image.py b/vllm_omni/model_executor/models/glm_image/glm_image.py deleted file mode 100644 index e69de29bb2d..00000000000