From 54896c9acc7139a4ca823508b8f19ca90e397cb0 Mon Sep 17 00:00:00 2001 From: Zeel Date: Fri, 17 Apr 2026 11:54:27 -0400 Subject: [PATCH] Optimize GLM-Image AR token upsampling and add profiling/tests Signed-off-by: Zeel --- .../test_glm_image_stage_input_processors.py | 179 ++++++++++++++++++ .../models/glm_image/glm_image_ar.py | 18 +- .../stage_input_processors/glm_image.py | 10 +- 3 files changed, 194 insertions(+), 13 deletions(-) create mode 100644 tests/model_executor/stage_input_processors/test_glm_image_stage_input_processors.py diff --git a/tests/model_executor/stage_input_processors/test_glm_image_stage_input_processors.py b/tests/model_executor/stage_input_processors/test_glm_image_stage_input_processors.py new file mode 100644 index 0000000000..fccc8c3b42 --- /dev/null +++ b/tests/model_executor/stage_input_processors/test_glm_image_stage_input_processors.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace + +import pytest +import torch + +from vllm_omni.model_executor.stage_input_processors.glm_image import ( + _parse_generated_tokens, + _upsample_token_ids, + ar2diffusion, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def _ar_output( + token_ids: list[int], + *, + multimodal_output: dict | None = None, +): + return SimpleNamespace( + outputs=[SimpleNamespace(token_ids=token_ids)], + multimodal_output=multimodal_output, + ) + + +def test_upsample_token_ids_matches_nearest_neighbor_layout(): + token_ids = torch.tensor([1, 2, 3, 4], dtype=torch.long) + + upsampled = _upsample_token_ids(token_ids, token_h=2, token_w=2) + + expected = torch.tensor( + [ + 1, + 1, + 2, + 2, + 1, + 1, + 2, + 2, + 3, + 3, + 4, + 4, + 3, + 3, + 4, + 4, + ], + dtype=torch.long, + ) + torch.testing.assert_close(upsampled, expected) + + +def test_ar2diffusion_builds_upsampled_prior_tokens_for_t2i(): + stage_list = [ + SimpleNamespace( + engine_outputs=[ + _ar_output([10, 11, 12, 13, 14], multimodal_output=None), + ] + ) + ] + + outputs = ar2diffusion( + stage_list=stage_list, + engine_input_source=[0], + prompt=[{"prompt": "hello", "height": 64, "width": 64}], + ) + + assert len(outputs) == 1 + assert outputs[0]["prompt"] == "hello" + assert outputs[0]["height"] == 64 + assert outputs[0]["width"] == 64 + torch.testing.assert_close( + outputs[0]["extra"]["prior_token_ids"], + torch.tensor([11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14], dtype=torch.long), + ) + + +def test_ar2diffusion_normalizes_serialized_prior_token_image_ids(): + stage_list = [ + SimpleNamespace( + engine_outputs=[ + _ar_output( + [21, 22, 23, 24], + multimodal_output={"prior_token_image_ids": [[101, 102, 103, 104]]}, + ), + ] + ) + ] + + outputs = ar2diffusion( + stage_list=stage_list, + engine_input_source=[0], + prompt=[ + { + "prompt": "edit", + "height": 64, + "width": 64, + "multi_modal_data": {"image": object()}, + } + ], + requires_multimodal_data=True, + ) + + prior_image_ids = outputs[0]["extra"]["prior_token_image_ids"] + assert isinstance(prior_image_ids, list) + assert len(prior_image_ids) == 1 + assert isinstance(prior_image_ids[0], torch.Tensor) + torch.testing.assert_close( + prior_image_ids[0], + torch.tensor([101, 102, 103, 104], dtype=torch.long), + ) + assert "pil_image" in outputs[0] + + +def test_ar2diffusion_uses_i2i_large_tokens_without_preview_prefix(): + stage_list = [ + SimpleNamespace( + engine_outputs=[ + _ar_output( + [31, 32, 33, 34, 16385], + multimodal_output={"prior_token_image_ids": [torch.tensor([201, 202, 203, 204], dtype=torch.long)]}, + ), + ] + ) + ] + + outputs = ar2diffusion( + stage_list=stage_list, + engine_input_source=[0], + prompt=[{"prompt": "edit", "height": 64, "width": 64}], + ) + + torch.testing.assert_close( + outputs[0]["extra"]["prior_token_ids"], + torch.tensor([31, 31, 32, 32, 31, 31, 32, 32, 33, 33, 34, 34, 33, 33, 34, 34], dtype=torch.long), + ) + torch.testing.assert_close( + outputs[0]["extra"]["prior_token_image_ids"][0], + torch.tensor([201, 202, 203, 204], dtype=torch.long), + ) + + +def test_ar2diffusion_reads_prior_token_image_ids_from_completion_output_fallback(): + output = SimpleNamespace( + token_ids=[41, 42, 43, 44], + multimodal_output={"prior_token_image_ids": [torch.tensor([301, 302, 303, 304], dtype=torch.long)]}, + ) + stage_list = [SimpleNamespace(engine_outputs=[SimpleNamespace(outputs=[output], multimodal_output=None)])] + + outputs = ar2diffusion( + stage_list=stage_list, + engine_input_source=[0], + prompt=[{"prompt": "fallback", "height": 64, "width": 64}], + ) + + torch.testing.assert_close( + outputs[0]["extra"]["prior_token_image_ids"][0], + torch.tensor([301, 302, 303, 304], dtype=torch.long), + ) + + +def test_parse_generated_tokens_adjusts_grid_for_truncated_output(): + prior_token_ids, pixel_h, pixel_w = _parse_generated_tokens( + [51, 52, 53, 54], + height=128, + width=128, + ) + + assert pixel_h == 64 + assert pixel_w == 64 + torch.testing.assert_close( + prior_token_ids, + torch.tensor([51, 51, 52, 52, 51, 51, 52, 52, 53, 53, 54, 54, 53, 53, 54, 54], dtype=torch.long), + ) diff --git a/vllm_omni/model_executor/models/glm_image/glm_image_ar.py b/vllm_omni/model_executor/models/glm_image/glm_image_ar.py index 31eed9b2cb..84c8b2acdc 100644 --- a/vllm_omni/model_executor/models/glm_image/glm_image_ar.py +++ b/vllm_omni/model_executor/models/glm_image/glm_image_ar.py @@ -241,6 +241,13 @@ def get_image_size_with_most_features(self) -> tuple[int, int]: return (image_size, image_size) +def _upsample_token_grid_nearest(token_ids: torch.Tensor, grid_h: int, grid_w: int) -> torch.Tensor: + """Upsample a 2D token grid by 2x using integer nearest-neighbor.""" + token_grid = token_ids.view(grid_h, grid_w) + token_grid = token_grid.repeat_interleave(2, dim=0).repeat_interleave(2, dim=1) + return token_grid.reshape(-1) + + class GlmImageDummyInputsBuilder(BaseDummyInputsBuilder[GlmImageProcessingInfo]): """ Builds dummy inputs for GLM-Image model profiling. @@ -2231,10 +2238,7 @@ def forward( for i, tokens in enumerate(image_tokens_list): grid_t, grid_h, grid_w = image_grid_thw[i].tolist() # Reshape to 2D grid - tokens_2d = tokens.view(1, 1, grid_h, grid_w) - # Upsample by 2x (nearest neighbor) - tokens_upsampled = F.interpolate(tokens_2d.float(), scale_factor=2, mode="nearest").to(dtype=torch.long) - upsampled_token_ids.append(tokens_upsampled.view(-1)) + upsampled_token_ids.append(_upsample_token_grid_nearest(tokens, grid_h, grid_w)) prior_token_image_ids_info = { "prior_token_image_ids": upsampled_token_ids, @@ -2439,11 +2443,9 @@ def _process_image_input( for i, tokens in enumerate(image_tokens_list): grid_t, grid_h, grid_w = image_grid_thw[i].tolist() # Reshape to 2D grid - tokens_2d = tokens.view(1, 1, grid_h, grid_w) - # Upsample by 2x (nearest neighbor) - tokens_upsampled = F.interpolate(tokens_2d.float(), scale_factor=2, mode="nearest").to(dtype=torch.long) + tokens_upsampled = _upsample_token_grid_nearest(tokens, grid_h, grid_w) # Keep as CPU tensor for proper serialization through pooling_output - upsampled_token_ids.append(tokens_upsampled.view(-1).detach().cpu().contiguous()) + upsampled_token_ids.append(tokens_upsampled.detach().cpu().contiguous()) # Note: We only include prior_token_image_ids in the info dict. # image_grid_thw is NOT included because: diff --git a/vllm_omni/model_executor/stage_input_processors/glm_image.py b/vllm_omni/model_executor/stage_input_processors/glm_image.py index 3063620bf8..e7216bea3a 100644 --- a/vllm_omni/model_executor/stage_input_processors/glm_image.py +++ b/vllm_omni/model_executor/stage_input_processors/glm_image.py @@ -27,10 +27,11 @@ def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) -> Returns: Upsampled token IDs of shape [num_tokens * 4] """ - token_ids = token_ids.view(1, 1, token_h, token_w) - token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to(dtype=torch.long) - token_ids = token_ids.view(-1) - return token_ids + token_grid = token_ids.view(token_h, token_w) + # Integer nearest-neighbor upsampling avoids the float cast/interpolate + # overhead in the AR -> diffusion bridge. + token_grid = token_grid.repeat_interleave(2, dim=0).repeat_interleave(2, dim=1) + return token_grid.reshape(-1) def _parse_generated_tokens( @@ -261,5 +262,4 @@ def ar2diffusion( diffusion_input[key] = original_prompt[key] diffusion_inputs.append(diffusion_input) - return diffusion_inputs