diff --git a/examples/offline_inference/glm_image/end2end.py b/examples/offline_inference/glm_image/end2end.py index 13bcd23f55a..7ae9478ba75 100644 --- a/examples/offline_inference/glm_image/end2end.py +++ b/examples/offline_inference/glm_image/end2end.py @@ -57,22 +57,22 @@ GLM_IMAGE_VISION_VOCAB_SIZE = 16512 # top_k should be vision_vocab_size -def compute_max_tokens(height: int, width: int, factor: int = 32) -> int: +def compute_max_tokens(height: int, width: int, factor: int = 32, is_i2i: bool = False) -> int: """ Compute max_new_tokens for GLM-Image AR generation. - GLM-Image generates tokens in this order for text-to-image: - 1. Small preview image (half resolution in each dimension) - 2. Large target image (full resolution) - 3. EOS token + GLM-Image generation differs by mode: + - text-to-image (t2i): small preview + large target + EOS + - image-to-image (i2i): large target + EOS Args: height: Target image height in pixels width: Target image width in pixels factor: Downsampling factor (32 for GLM-Image AR output) + is_i2i: Whether the request is image-to-image mode Returns: - Total number of tokens to generate (small + large + EOS) + Total number of tokens to generate for the specified mode """ # Large image tokens (target resolution) token_h = height // factor @@ -80,11 +80,15 @@ def compute_max_tokens(height: int, width: int, factor: int = 32) -> int: large_tokens = token_h * token_w # Small preview tokens (half resolution in each dimension) - small_h = token_h // 2 - small_w = token_w // 2 - small_tokens = small_h * small_w + import math - # Total: small + large + EOS + ratio = token_h / token_w if token_w > 0 else 1.0 + small_token_h = max(1, int(math.sqrt(ratio) * (factor // 2))) + small_token_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2))) + small_tokens = small_token_h * small_token_w + + if is_i2i: + return large_tokens + 1 return small_tokens + large_tokens + 1 @@ -282,14 +286,18 @@ def main(args: argparse.Namespace) -> None: # Compute max_tokens dynamically based on target image size target_height = prompt_dict.get("height", 1024) target_width = prompt_dict.get("width", 1024) - calculated_max_tokens = compute_max_tokens(target_height, target_width) + is_i2i = source_image is not None + calculated_max_tokens = compute_max_tokens(target_height, target_width, is_i2i=is_i2i) # Use calculated value unless user explicitly specified a different value # Default args.max_tokens is 16384 (very large), so prefer calculated value effective_max_tokens = calculated_max_tokens if args.max_tokens == 16384 else args.max_tokens if args.verbose: - print(f"AR max_tokens: {effective_max_tokens} (calculated: {calculated_max_tokens}, arg: {args.max_tokens})") + print( + f"AR max_tokens: {effective_max_tokens} " + f"(calculated: {calculated_max_tokens}, arg: {args.max_tokens}, mode: {'i2i' if is_i2i else 't2i'})" + ) # IMPORTANT: GLM-Image AR model requires these exact sampling parameters # from generation_config.json for proper image token generation. @@ -303,6 +311,12 @@ def main(args: argparse.Namespace) -> None: stop_token_ids=[GLM_IMAGE_EOS_TOKEN_ID], # 16385, CRITICAL for stopping seed=args.seed, detokenize=False, + # Keep target size available in runner/model for deterministic M-RoPE + # decode grids in t2i (no mm_features available in this path). + extra_args={ + "target_h": int(target_height), + "target_w": int(target_width), + }, ) # For diffusion stage, sampling_params contains diffusion-specific parameters diff --git a/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py b/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py index 4190b1fbb13..0a9d45a8ccb 100644 --- a/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py +++ b/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py @@ -100,6 +100,10 @@ def mock_request(mocker: MockerFixture): request.stop_token_ids = None request.frequency_penalty = None request.presence_penalty = None + # Must be real Python objects (not MagicMock) so the code's explicit-field + # and extra_body checks work correctly. + request.model_fields_set = set() + request.extra_body = {} return request @@ -150,6 +154,7 @@ def test_preserves_yaml_defaults_when_no_request_params(serving_chat, mock_reque def test_request_temperature_overrides_yaml_default(serving_chat, mock_request): """Test that request temperature overrides YAML default.""" mock_request.temperature = 0.8 + mock_request.model_fields_set = {"temperature"} result = serving_chat._build_sampling_params_list_from_request(mock_request) @@ -162,6 +167,7 @@ def test_request_temperature_overrides_yaml_default(serving_chat, mock_request): def test_request_top_p_overrides_yaml_default(serving_chat, mock_request): """Test that request top_p overrides YAML default.""" mock_request.top_p = 0.95 + mock_request.model_fields_set = {"top_p"} result = serving_chat._build_sampling_params_list_from_request(mock_request) @@ -173,6 +179,7 @@ def test_request_top_p_overrides_yaml_default(serving_chat, mock_request): def test_request_max_tokens_overrides_yaml_default(serving_chat, mock_request): """Test that request max_tokens overrides YAML default.""" mock_request.max_tokens = 100 + mock_request.model_fields_set = {"max_tokens"} result = serving_chat._build_sampling_params_list_from_request(mock_request) @@ -189,6 +196,7 @@ def test_max_tokens_uses_yaml_default_when_not_specified(serving_chat, mock_requ def test_request_seed_overrides_yaml_default(serving_chat, mock_request): """Test that request seed overrides YAML default.""" mock_request.seed = 123 + mock_request.model_fields_set = {"seed"} result = serving_chat._build_sampling_params_list_from_request(mock_request) @@ -200,6 +208,7 @@ def test_request_seed_overrides_yaml_default(serving_chat, mock_request): def test_request_frequency_penalty_overrides(serving_chat, mock_request): """Test that request frequency_penalty is applied.""" mock_request.frequency_penalty = 0.5 + mock_request.model_fields_set = {"frequency_penalty"} result = serving_chat._build_sampling_params_list_from_request(mock_request) @@ -209,6 +218,7 @@ def test_request_frequency_penalty_overrides(serving_chat, mock_request): def test_request_presence_penalty_overrides(serving_chat, mock_request): """Test that request presence_penalty is applied.""" mock_request.presence_penalty = 0.3 + mock_request.model_fields_set = {"presence_penalty"} result = serving_chat._build_sampling_params_list_from_request(mock_request) @@ -235,6 +245,7 @@ def test_multiple_params_override_together(serving_chat, mock_request): mock_request.temperature = 0.7 mock_request.top_p = 0.85 mock_request.seed = 999 + mock_request.model_fields_set = {"max_tokens", "temperature", "top_p", "seed"} result = serving_chat._build_sampling_params_list_from_request(mock_request) @@ -275,6 +286,7 @@ def test_apply_request_overrides_applies_values(serving_chat, mock_request, defa """Test that _apply_request_overrides applies non-None request values.""" mock_request.temperature = 0.8 mock_request.seed = 123 + mock_request.model_fields_set = {"temperature", "seed"} result = serving_chat._apply_request_overrides(default_comprehension_params, mock_request) @@ -304,6 +316,8 @@ def test_apply_overrides_empty_stop_list_preserves_default(serving_chat, mocker) request.stop_token_ids = None request.frequency_penalty = None request.presence_penalty = None + request.model_fields_set = {"stop"} + request.extra_body = {} result = serving_chat._apply_request_overrides(default_params, request) @@ -325,6 +339,8 @@ def test_apply_overrides_nonempty_stop_list_overrides_default(serving_chat, mock request.stop_token_ids = None request.frequency_penalty = None request.presence_penalty = None + request.model_fields_set = {"stop"} + request.extra_body = {} result = serving_chat._apply_request_overrides(default_params, request) @@ -367,6 +383,8 @@ def test_apply_overrides_nonempty_stop_token_ids_overrides_default(serving_chat, request.stop_token_ids = [100] # non-empty list — should override request.frequency_penalty = None request.presence_penalty = None + request.model_fields_set = {"stop_token_ids"} + request.extra_body = {} result = serving_chat._apply_request_overrides(default_params, request) @@ -392,6 +410,8 @@ def test_apply_overrides_mixed_empty_and_nonempty_lists(serving_chat, mocker): request.stop_token_ids = [100, 200] # non-empty — SHOULD override request.frequency_penalty = None request.presence_penalty = None + request.model_fields_set = {"temperature", "stop", "stop_token_ids"} + request.extra_body = {} result = serving_chat._apply_request_overrides(default_params, request) @@ -415,6 +435,8 @@ def test_apply_overrides_none_scalar_still_preserves_default(serving_chat, mocke request.stop_token_ids = None request.frequency_penalty = None request.presence_penalty = None + request.model_fields_set = set() + request.extra_body = {} result = serving_chat._apply_request_overrides(default_params, request) @@ -442,6 +464,8 @@ def test_apply_overrides_both_lists_empty_preserves_defaults(serving_chat, mocke request.stop_token_ids = [] request.frequency_penalty = None request.presence_penalty = None + request.model_fields_set = {"stop", "stop_token_ids"} + request.extra_body = {} result = serving_chat._apply_request_overrides(default_params, request) @@ -511,3 +535,165 @@ def test_get_comprehension_stage_index_raises_when_not_found(mocker: MockerFixtu with pytest.raises(ValueError, match="No comprehension stage"): instance._get_comprehension_stage_index() + + +# ============================================================================= +# Tests for _resolve_height_width_from_extra_body +# ============================================================================= + + +class TestResolveHeightWidth: + def test_explicit_height_width(self): + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"height": 512, "width": 768}) + assert h == 512 + assert w == 768 + + def test_size_string(self): + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"size": "768x512"}) + assert w == 768 + assert h == 512 + + def test_size_string_uppercase(self): + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"size": "768X512"}) + assert w == 768 + assert h == 512 + + def test_size_fallback_when_height_missing(self): + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"size": "512x512", "width": 1024}) + # height is None -> size fallback fires and sets BOTH width and height + assert h == 512 + assert w == 512 + + def test_empty_extra_body(self): + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({}) + assert h is None + assert w is None + + def test_invalid_size_format_ignored(self): + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"size": "invalid"}) + assert h is None + assert w is None + + +# ============================================================================= +# Tests for _apply_request_overrides with GLM-Image (max_tokens computation) +# ============================================================================= + + +class TestApplyRequestOverridesGLMImage: + """Test dynamic max_tokens computation for GLM-Image AR stage.""" + + @pytest.fixture + def glm_serving_chat(self, mock_engine_client, mocker: MockerFixture): + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + instance = object.__new__(OmniOpenAIServingChat) + instance.engine_client = mock_engine_client + # Mock the image extraction to return no reference images (t2i by default) + instance._extract_diffusion_prompt_and_images_from_messages = mocker.MagicMock(return_value=("a cat", [])) + return instance + + @pytest.fixture + def glm_request(self, mocker: MockerFixture): + req = mocker.MagicMock() + req.temperature = None + req.top_p = None + req.top_k = None + req.max_tokens = None + req.min_tokens = None + req.seed = None + req.ignore_eos = None + req.stop = None + req.stop_token_ids = None + req.frequency_penalty = None + req.presence_penalty = None + req.extra_body = {"height": 1024, "width": 1024} + req.model_fields_set = set() + return req + + def test_t2i_computes_max_tokens(self, glm_serving_chat, glm_request, default_comprehension_params): + """t2i mode: max_tokens computed from height/width, no reference images.""" + result = glm_serving_chat._apply_request_overrides(default_comprehension_params, glm_request) + # t2i 1024x1024 = 256 + 1024 + 1 = 1281 + assert result.max_tokens == 1281 + assert result.extra_args["target_h"] == 1024 + assert result.extra_args["target_w"] == 1024 + + def test_i2i_computes_fewer_tokens( + self, glm_serving_chat, glm_request, default_comprehension_params, mocker: MockerFixture + ): + """i2i mode: max_tokens should be smaller than t2i for same dimensions.""" + # Make it detect reference images + glm_serving_chat._extract_diffusion_prompt_and_images_from_messages = mocker.MagicMock( + return_value=("edit this", ["fake_image"]) + ) + + result = glm_serving_chat._apply_request_overrides(default_comprehension_params, glm_request) + # i2i 1024x1024 = 1024 + 1 = 1025 + assert result.max_tokens == 1025 + + def test_dynamic_max_tokens_overrides_user_value(self, glm_serving_chat, glm_request, default_comprehension_params): + """When height/width are provided, dynamic computation overrides user max_tokens.""" + glm_request.max_tokens = 500 + glm_request.model_fields_set = {"max_tokens"} + + result = glm_serving_chat._apply_request_overrides(default_comprehension_params, glm_request) + # Dynamic computation from height/width always wins when present + assert result.max_tokens == 1281 + + def test_no_height_width_preserves_default( + self, glm_serving_chat, mocker: MockerFixture, default_comprehension_params + ): + """When no height/width in extra_body, keep YAML default max_tokens.""" + req = mocker.MagicMock() + req.temperature = None + req.top_p = None + req.top_k = None + req.max_tokens = None + req.min_tokens = None + req.seed = None + req.ignore_eos = None + req.stop = None + req.stop_token_ids = None + req.frequency_penalty = None + req.presence_penalty = None + req.extra_body = {} + req.model_fields_set = set() + + result = glm_serving_chat._apply_request_overrides(default_comprehension_params, req) + assert result.max_tokens == 2048 # YAML default + + def test_size_string_parsed_for_glm_image( + self, glm_serving_chat, mocker: MockerFixture, default_comprehension_params + ): + """'size' in extra_body is parsed as fallback for height/width.""" + req = mocker.MagicMock() + req.temperature = None + req.top_p = None + req.top_k = None + req.max_tokens = None + req.min_tokens = None + req.seed = None + req.ignore_eos = None + req.stop = None + req.stop_token_ids = None + req.frequency_penalty = None + req.presence_penalty = None + req.extra_body = {"size": "512x512"} + req.model_fields_set = set() + + result = glm_serving_chat._apply_request_overrides(default_comprehension_params, req) + # 512x512 t2i = 256 + 256 + 1 = 513 + assert result.max_tokens == 513 diff --git a/tests/model_executor/models/glm_image/test_glm_image_ar.py b/tests/model_executor/models/glm_image/test_glm_image_ar.py new file mode 100644 index 00000000000..32a016b2a67 --- /dev/null +++ b/tests/model_executor/models/glm_image/test_glm_image_ar.py @@ -0,0 +1,352 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for GLM-Image AR model: DataParser, processor, and M-RoPE.""" + +import importlib.util +import os +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest +import torch + +# --------------------------------------------------------------------------- +# Load target classes via importlib to avoid requiring transformers.models.glm_image +# (which may not exist in CI). This follows the same pattern as +# tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py. +# --------------------------------------------------------------------------- + +_BASE = os.path.join( + os.path.dirname(__file__), + os.pardir, + os.pardir, + os.pardir, + os.pardir, + "vllm_omni", + "model_executor", + "models", + "glm_image", +) + + +def _load_module(name: str, filename: str): + path = os.path.abspath(os.path.join(_BASE, filename)) + spec = importlib.util.spec_from_file_location(name, path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +def _build_mock_modules() -> dict[str, object]: + """Build the dict of modules to inject into sys.modules.""" + # Stub transformers.models.glm_image submodules + glm_image_mod = types.ModuleType("transformers.models.glm_image") + glm_config_mod = types.ModuleType("transformers.models.glm_image.configuration_glm_image") + glm_config_mod.GlmImageConfig = type("GlmImageConfig", (), {}) + glm_config_mod.GlmImageTextConfig = type("GlmImageTextConfig", (), {}) + glm_config_mod.GlmImageVisionConfig = type("GlmImageVisionConfig", (), {}) + glm_config_mod.GlmImageVQVAEConfig = type("GlmImageVQVAEConfig", (), {}) + glm_proc_mod = types.ModuleType("transformers.models.glm_image.processing_glm_image") + glm_proc_mod.GlmImageProcessor = type("GlmImageProcessor", (), {}) + + # vllm_omni submodules needed by the import chain + vllm_omni_mod = MagicMock() + vllm_omni_models = types.ModuleType("vllm_omni.model_executor.models") + vllm_omni_glm_image_pkg = types.ModuleType("vllm_omni.model_executor.models.glm_image") + vllm_omni_glm_image_pkg.__path__ = [os.path.abspath(_BASE)] + vllm_omni_output = MagicMock() + + return { + "transformers.models.glm_image": glm_image_mod, + "transformers.models.glm_image.configuration_glm_image": glm_config_mod, + "transformers.models.glm_image.processing_glm_image": glm_proc_mod, + "vllm_omni": vllm_omni_mod, + "vllm_omni.model_executor": types.ModuleType("vllm_omni.model_executor"), + "vllm_omni.model_executor.models": vllm_omni_models, + "vllm_omni.model_executor.models.glm_image": vllm_omni_glm_image_pkg, + "vllm_omni.model_executor.models.output_templates": vllm_omni_output, + } + + +def _load_target_classes(): + """Load the glm_image_ar module with mocked dependencies.""" + mocks = _build_mock_modules() + with patch.dict(sys.modules, mocks): + mod = _load_module( + "vllm_omni.model_executor.models.glm_image.glm_image_ar", + "glm_image_ar.py", + ) + sys.modules["vllm_omni.model_executor.models.glm_image.glm_image_ar"] = mod + return mod + + +_ar_mod = _load_target_classes() + +GlmImageDataParser = _ar_mod.GlmImageDataParser +GlmImageMultiModalProcessor = _ar_mod.GlmImageMultiModalProcessor +GlmImageForConditionalGeneration = _ar_mod.GlmImageForConditionalGeneration +GlmImageRotaryEmbedding = _ar_mod.GlmImageRotaryEmbedding + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +# ============================================================================= +# Helper: Minimal config for testing +# ============================================================================= + + +def _make_hf_config(**overrides): + """Create a minimal GlmImageConfig-like object for testing.""" + defaults = { + "image_token_id": 167855, + "image_start_token_id": 16384, + "image_end_token_id": 16385, + "grid_bos_token_id": None, + "grid_eos_token_id": None, + } + defaults.update(overrides) + from types import SimpleNamespace + + return SimpleNamespace(**defaults) + + +# ============================================================================= +# Tests for GlmImageDataParser +# ============================================================================= + + +class TestGlmImageDataParser: + """Test that img2img key is normalized to image in the data parser.""" + + def test_img2img_normalized_to_image(self): + parser = GlmImageDataParser.__new__(GlmImageDataParser) + parser._expected_hidden_size = 4096 + # The _get_subparsers should include img2img + subparsers = parser._get_subparsers() + assert "img2img" in subparsers + assert subparsers["img2img"] == parser._parse_image_data + + def test_parse_mm_data_normalizes_img2img(self): + parser = GlmImageDataParser.__new__(GlmImageDataParser) + parser._expected_hidden_size = 4096 + # Create a mock for the parent parse_mm_data + original_parse = type(parser).parse_mm_data + + calls = [] + + def mock_parse(mm_data, **kwargs): + calls.append(mm_data) + return MagicMock() + + # Monkey-patch temporarily + type(parser).parse_mm_data = mock_parse + try: + parser.parse_mm_data({"img2img": "fake_image"}) + except Exception: + pass # parse might fail on mock, we just check the normalization + finally: + type(parser).parse_mm_data = original_parse + + # Verify that "img2img" was normalized to "image" + if calls: + assert "image" in calls[0] + assert "img2img" not in calls[0] + + +# ============================================================================= +# Tests for _build_generation_grids +# ============================================================================= + + +class TestBuildGenerationGrids: + """Test M-RoPE grid construction for t2i mode.""" + + @pytest.fixture + def processor(self): + """Create a minimal processor instance with mocked info.""" + proc = object.__new__(GlmImageMultiModalProcessor) + proc.info = MagicMock() + return proc + + def test_1024x1024(self, processor): + kwargs = {"target_h": 1024, "target_w": 1024} + grids = processor._build_generation_grids(kwargs) + # token_h = 32, token_w = 32 + # ratio = 1.0, small_h = 16, small_w = 16 + assert grids.shape == (2, 3) + assert grids[0].tolist() == [1, 32, 32] # large + assert grids[1].tolist() == [1, 16, 16] # small + + def test_512x512(self, processor): + kwargs = {"target_h": 512, "target_w": 512} + grids = processor._build_generation_grids(kwargs) + assert grids.shape == (2, 3) + assert grids[0].tolist() == [1, 16, 16] + # small: ratio=1.0, small_h=int(sqrt(1)*16)=16, small_w=16 + assert grids[1].tolist() == [1, 16, 16] + + def test_non_square(self, processor): + kwargs = {"target_h": 1024, "target_w": 512} + grids = processor._build_generation_grids(kwargs) + # token_h = 32, token_w = 16, ratio = 2.0 + # small_h = int(sqrt(2)*16) = 22, small_w = int(sqrt(0.5)*16) = 11 + assert grids[0].tolist() == [1, 32, 16] + assert grids[1].tolist() == [1, 22, 11] + + def test_defaults_to_1024_when_no_target(self, processor): + kwargs = {} + grids = processor._build_generation_grids(kwargs) + assert grids[0].tolist() == [1, 32, 32] + + def test_height_width_fallback(self, processor): + kwargs = {"height": 512, "width": 512} + grids = processor._build_generation_grids(kwargs) + assert grids[0].tolist() == [1, 16, 16] + + def test_aligned_to_factor(self, processor): + # 1000 not aligned to 32, should be rounded down to 992 + kwargs = {"target_h": 1000, "target_w": 1000} + grids = processor._build_generation_grids(kwargs) + # 1000 // 32 = 31 + assert grids[0].tolist() == [1, 31, 31] + + +# ============================================================================= +# Tests for get_mrope_input_positions +# ============================================================================= + + +class TestGetMropeInputPositions: + """Test M-RoPE position ID computation.""" + + @pytest.fixture + def model(self): + """Create a minimal model instance for M-RoPE testing.""" + model = object.__new__(GlmImageForConditionalGeneration) + model.config = _make_hf_config() + return model + + def test_pure_text(self, model): + """Pure text tokens: all 3 dimensions get same sequential positions.""" + input_tokens = [100, 101, 102, 103] + positions, delta = model.get_mrope_input_positions(input_tokens) + assert positions.shape == (3, 4) + # All three dims should be [0, 1, 2, 3] + for dim in range(3): + assert positions[dim].tolist() == [0, 1, 2, 3] + assert delta == 0 # max(3) + 1 - seq_len(4) = 0 + + def test_t2i_with_target_size(self, model): + """t2i with explicit target_h/target_w: grids built from them.""" + input_tokens = [100, 101, 102, 16384] # text + + kwargs = {"target_h": 256, "target_w": 256} + + positions, delta = model.get_mrope_input_positions(input_tokens, **kwargs) + # 256/32=8 -> grids = [[1,8,8], [1,16,16]] (small uses factor//2=16 base) + # Decode order (reversed): grid[-1]=[1,16,16]=256, grid[-2]=[1,8,8]=64, EOS=1 + total_decode = 256 + 64 + 1 # 321 + assert positions.shape == (3, 4 + total_decode) + # delta = max_position + 1 - seq_len + # Positions advance by max(h,w) per grid: max(16,16)=16, max(8,8)=8 + # max_pos = seq_len(4) + 16 + 8 = 28, then EOS at 28 + # delta = 28 + 1 - 4 = 25 + assert delta == 25 + + def test_t2i_1024_default_grids(self, model): + """t2i with default 1024x1024 grids when no explicit target size.""" + # Prompt ending with image_start_token_id but no image_end_token_id + input_tokens = [100, 101, 16384] + # No target_h/target_w, no mrope_image_grid_thw + # Falls back to token parsing then to default [[1,32,32], [1,16,16]] + positions, delta = model.get_mrope_input_positions(input_tokens) + assert positions.shape[0] == 3 + + def test_i2i_with_mrope_grid(self, model): + """i2i: mrope_image_grid_thw contains source + target grids.""" + # Source image tokens: [16384, 167855*4, 16385] + text + 16384(bos) + source_grid = [1, 2, 2] # 2x2 = 4 image tokens + target_grid = [1, 32, 32] # 32x32 = 1024 tokens + mrope_grid = torch.tensor([source_grid, target_grid], dtype=torch.long) + + # input_tokens: text + + 4*image_token + + + input_tokens = [100, 101, 16384] + [167855] * 4 + [16385, 16384] + + positions, delta = model.get_mrope_input_positions(input_tokens, mrope_image_grid_thw=mrope_grid) + + # 1 source image (num_complete_images=1), 1 target grid (num_decode_grids=1) + # Prefill covers all input tokens + # Decode covers: 32*32 + 1(EOS) = 1025 tokens + assert positions.shape[0] == 3 + + def test_position_delta_non_negative(self, model): + """mrope_position_delta should be non-negative for valid inputs.""" + input_tokens = [100, 16384] + kwargs = {"target_h": 64, "target_w": 64} + positions, delta = model.get_mrope_input_positions(input_tokens, **kwargs) + assert delta >= 0 + + +# ============================================================================= +# Tests for GlmImageRotaryEmbedding._apply_mrope +# ============================================================================= + + +class TestGlmImageRotaryEmbedding: + """Test M-RoPE section interleaving in the rotary embedding.""" + + @pytest.fixture + def rotary_emb(self): + # mrope_section=[8,12,12] sums to 32, so rotary_dim//2 must be >= 32 + # -> head_dim=64 gives rotary_dim=64, rotary_dim//2=32 + return GlmImageRotaryEmbedding(head_dim=64, mrope_section=[8, 12, 12]) + + def test_apply_mrope_shape(self, rotary_emb): + """Output shape matches [num_tokens, rotary_dim // 2].""" + freqs = torch.randn(3, 5, 32) # 3 dims, 5 tokens, rotary_dim//2=32 + result = rotary_emb._apply_mrope(freqs) + assert result.shape == (5, 32) + + def test_apply_mrope_interleaving(self, rotary_emb): + """Verify that M-RoPE correctly interleaves T/H/W sections.""" + # mrope_section = [8, 12, 12] splits dim 32 into 3 chunks: [8, 12, 12] + # chunk 0 (size 8): dim 0 % 3 = 0 (temporal) + # chunk 1 (size 12): dim 1 % 3 = 1 (height) + # chunk 2 (size 12): dim 2 % 3 = 2 (width) + freqs = torch.ones(3, 1, 32) + freqs[0, :, :] = 1.0 # temporal + freqs[1, :, :] = 2.0 # height + freqs[2, :, :] = 3.0 # width + + result = rotary_emb._apply_mrope(freqs) + assert result.shape == (1, 32) + assert (result[0, :8] == 1.0).all() # chunk 0: temporal + assert (result[0, 8:20] == 2.0).all() # chunk 1: height + assert (result[0, 20:32] == 3.0).all() # chunk 2: width + + def test_forward_1d_positions(self, rotary_emb): + """Forward with 1D positions (text-only) produces correct shapes.""" + positions = torch.arange(10) # [10] + q = torch.randn(10, 64) + k = torch.randn(10, 64) + q_out, k_out = rotary_emb(positions, q, k) + assert q_out.shape == (10, 64) + assert k_out.shape == (10, 64) + + def test_forward_3d_positions(self, rotary_emb): + """Forward with 3D M-RoPE positions produces correct shapes.""" + positions = torch.arange(30).reshape(3, 10) # [3, 10] + q = torch.randn(10, 64) + k = torch.randn(10, 64) + q_out, k_out = rotary_emb(positions, q, k) + assert q_out.shape == (10, 64) + assert k_out.shape == (10, 64) + + def test_forward_preserves_dtype(self, rotary_emb): + """Output dtype matches input dtype.""" + positions = torch.arange(5) + q = torch.randn(5, 64, dtype=torch.float32) + k = torch.randn(5, 64, dtype=torch.float32) + q_out, k_out = rotary_emb(positions, q, k) + assert q_out.dtype == torch.float32 + assert k_out.dtype == torch.float32 diff --git a/tests/model_executor/stage_input_processors/test_glm_image.py b/tests/model_executor/stage_input_processors/test_glm_image.py new file mode 100644 index 00000000000..028d3e7cd9d --- /dev/null +++ b/tests/model_executor/stage_input_processors/test_glm_image.py @@ -0,0 +1,403 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for GLM-Image stage input processor.""" + +from types import SimpleNamespace + +import pytest +import torch + +from vllm_omni.model_executor.stage_input_processors.glm_image import ( + _first_source_image, + _has_source_image, + _parse_generated_tokens, + _upsample_token_ids, + ar2diffusion, + compute_max_tokens, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +# ============================================================================= +# Helpers +# ============================================================================= + + +def _source_output(token_ids: list[int], mm_output: dict | None = None): + """Create a minimal AR output mock.""" + return SimpleNamespace( + outputs=[SimpleNamespace(token_ids=token_ids)], + multimodal_output=mm_output, + ) + + +def _stage_with_outputs(outputs): + """Create a stage list entry with engine_outputs.""" + return SimpleNamespace(engine_outputs=outputs) + + +# ============================================================================= +# Tests for _has_source_image +# ============================================================================= + + +class TestHasSourceImage: + def test_none_input(self): + assert _has_source_image(None) is False + + def test_non_dict_input(self): + assert _has_source_image("not_a_dict") is False + + def test_empty_dict(self): + assert _has_source_image({}) is False + + def test_image_key_present(self): + from PIL import Image + + img = Image.new("RGB", (64, 64)) + assert _has_source_image({"image": img}) is True + + def test_image_key_none(self): + assert _has_source_image({"image": None}) is False + + def test_img2img_key_present(self): + from PIL import Image + + img = Image.new("RGB", (64, 64)) + assert _has_source_image({"img2img": img}) is True + + def test_images_key_list(self): + from PIL import Image + + imgs = [Image.new("RGB", (64, 64))] + assert _has_source_image({"images": imgs}) is True + + def test_images_key_empty_list(self): + assert _has_source_image({"images": []}) is False + + def test_images_key_single(self): + from PIL import Image + + img = Image.new("RGB", (64, 64)) + assert _has_source_image({"images": img}) is True + + +# ============================================================================= +# Tests for _first_source_image +# ============================================================================= + + +class TestFirstSourceImage: + def test_none_input(self): + assert _first_source_image(None) is None + + def test_non_dict_input(self): + assert _first_source_image("not_a_dict") is None + + def test_image_key_single(self): + from PIL import Image + + img = Image.new("RGB", (64, 64)) + assert _first_source_image({"image": img}) is img + + def test_image_key_list(self): + from PIL import Image + + img = Image.new("RGB", (64, 64)) + assert _first_source_image({"image": [img]}) is img + + def test_image_key_empty_list(self): + assert _first_source_image({"image": []}) is None + + def test_img2img_key_single(self): + from PIL import Image + + img = Image.new("RGB", (64, 64)) + assert _first_source_image({"img2img": img}) is img + + def test_images_key_list(self): + from PIL import Image + + imgs = [Image.new("RGB", (64, 64))] + assert _first_source_image({"images": imgs}) is imgs[0] + + def test_images_key_empty_list(self): + assert _first_source_image({"images": []}) is None + + def test_images_key_single_not_list(self): + from PIL import Image + + img = Image.new("RGB", (64, 64)) + assert _first_source_image({"images": img}) is img + + +# ============================================================================= +# Tests for compute_max_tokens +# ============================================================================= + + +class TestComputeMaxTokens: + def test_t2i_1024x1024(self): + # t2i: small_tokens + large_tokens + 1 (EOS) + # token_h = 1024/32 = 32, token_w = 1024/32 = 32 + # large = 32*32 = 1024 + # ratio = 1.0, small_h = sqrt(1)*16 = 16, small_w = sqrt(1)*16 = 16, small = 256 + # total = 256 + 1024 + 1 = 1281 + result = compute_max_tokens(1024, 1024, is_i2i=False) + assert result == 1281 + + def test_i2i_1024x1024(self): + # i2i: large_tokens + 1 (EOS) + # large = 32*32 = 1024, total = 1025 + result = compute_max_tokens(1024, 1024, is_i2i=True) + assert result == 1025 + + def test_t2i_512x512(self): + # token_h = 16, token_w = 16, large = 256 + # ratio = 1.0, small_h = 16, small_w = 16, small = 256 + # total = 256 + 256 + 1 = 513 + result = compute_max_tokens(512, 512, is_i2i=False) + assert result == 513 + + def test_i2i_512x512(self): + # large = 256, total = 257 + result = compute_max_tokens(512, 512, is_i2i=True) + assert result == 257 + + def test_non_square_t2i(self): + # 1024x512: token_h=32, token_w=16, large=512 + # ratio = 32/16 = 2.0 + # small_h = max(1, int(sqrt(2)*16)) = 22, small_w = max(1, int(sqrt(0.5)*16)) = 11 + # small = 22*11 = 242 + # total = 242 + 512 + 1 = 755 + result = compute_max_tokens(1024, 512, is_i2i=False) + assert result == 242 + 512 + 1 + + def test_custom_factor(self): + # factor=16, 512x512: token_h=32, token_w=32, large=1024 + # ratio=1.0, small_h=8, small_w=8, small=64 + # total = 64 + 1024 + 1 = 1089 + result = compute_max_tokens(512, 512, factor=16, is_i2i=False) + assert result == 1089 + + def test_i2i_smaller_than_t2i(self): + t2i = compute_max_tokens(1024, 1024, is_i2i=False) + i2i = compute_max_tokens(1024, 1024, is_i2i=True) + assert i2i < t2i + + +# ============================================================================= +# Tests for _upsample_token_ids +# ============================================================================= + + +class TestUpsampleTokenIds: + def test_2x2_to_4x4(self): + tokens = torch.tensor([1, 2, 3, 4]) + result = _upsample_token_ids(tokens, 2, 2) + assert result.shape == (16,) # 4 * 4 = 16 (2x each dim) + + def test_1x1_to_2x2(self): + tokens = torch.tensor([7]) + result = _upsample_token_ids(tokens, 1, 1) + assert result.shape == (4,) # 2 * 2 + assert (result == 7).all() + + def test_4x4_to_8x8(self): + tokens = torch.arange(16, dtype=torch.long) + result = _upsample_token_ids(tokens, 4, 4) + assert result.shape == (64,) + + def test_preserves_dtype(self): + tokens = torch.tensor([1, 2, 3, 4], dtype=torch.long) + result = _upsample_token_ids(tokens, 2, 2) + assert result.dtype == torch.long + + +# ============================================================================= +# Tests for _parse_generated_tokens +# ============================================================================= + + +class TestParseGeneratedTokens: + def test_t2i_standard(self): + # 1024x1024, t2i: small(256) + large(1024) + EOS + # Generate 256 + 1024 + 1 = 1281 tokens, last is EOS (16385) + large_tokens = list(range(1024)) + small_tokens = list(range(1000, 1256)) + eos = [16385] + token_ids = small_tokens + large_tokens + eos + + prior, h, w = _parse_generated_tokens(token_ids, 1024, 1024, is_i2i=False) + assert h == 1024 + assert w == 1024 + # Prior tokens should be upsampled: 1024 tokens -> 4*1024 = 4096 + assert prior.shape[0] == 1024 * 4 + + def test_i2i_standard(self): + # 1024x1024, i2i: large(1024) + EOS + large_tokens = list(range(1024)) + eos = [16385] + token_ids = large_tokens + eos + + prior, h, w = _parse_generated_tokens(token_ids, 1024, 1024, is_i2i=True) + assert h == 1024 + assert w == 1024 + assert prior.shape[0] == 1024 * 4 + + def test_i2i_without_eos(self): + # i2i without EOS marker + large_tokens = list(range(1024)) + prior, h, w = _parse_generated_tokens(large_tokens, 1024, 1024, is_i2i=True) + assert h == 1024 + assert w == 1024 + + def test_i2i_too_few_tokens_raises(self): + with pytest.raises(ValueError, match="i2i token parse failed"): + _parse_generated_tokens([1, 2, 3], 1024, 1024, is_i2i=True) + + def test_t2i_too_few_tokens_raises(self): + # Only large tokens, no small preview + large_tokens = list(range(1024)) + with pytest.raises(ValueError, match="t2i token parse failed"): + _parse_generated_tokens(large_tokens, 1024, 1024, is_i2i=False) + + def test_i2i_t2i_style_layout_fallback(self): + # i2i but got t2i-style (small + large) tokens + small_tokens = list(range(256)) + large_tokens = list(range(1024)) + token_ids = small_tokens + large_tokens + + prior, h, w = _parse_generated_tokens(token_ids, 1024, 1024, is_i2i=True) + # Should extract the large portion + assert h == 1024 + assert w == 1024 + + +# ============================================================================= +# Tests for ar2diffusion +# ============================================================================= + + +class TestAr2Diffusion: + def test_basic_t2i(self): + """Test basic text-to-image pipeline: AR -> Diffusion.""" + # 1024x1024 t2i: small(256) + large(1024) + EOS + token_ids = list(range(256)) + list(range(1024)) + [16385] + stage_list = [_stage_with_outputs([_source_output(token_ids)])] + + prompt = {"prompt": "a cat", "mm_processor_kwargs": {"target_h": 1024, "target_w": 1024}} + + result = ar2diffusion(stage_list, [0], prompt=[prompt]) + assert len(result) == 1 + assert result[0]["prompt"] == "a cat" + assert result[0]["height"] == 1024 + assert result[0]["width"] == 1024 + assert "prior_token_ids" in result[0]["extra"] + + def test_i2i_with_mm_output(self): + """Test image-to-image with prior_token_image_ids from AR model.""" + token_ids = list(range(1024)) + [16385] + mm_output = {"prior_token_image_ids": torch.tensor([1, 2, 3])} + stage_list = [_stage_with_outputs([_source_output(token_ids, mm_output)])] + + from PIL import Image + + img = Image.new("RGB", (64, 64)) + prompt = { + "prompt": "edit this", + "mm_processor_kwargs": {"target_h": 1024, "target_w": 1024}, + "multi_modal_data": {"image": img}, + } + + result = ar2diffusion(stage_list, [0], prompt=[prompt]) + assert len(result) == 1 + assert result[0]["extra"]["prior_token_image_ids"] is not None + + def test_i2i_detected_via_modalities(self): + """Test i2i mode detected via modalities field.""" + token_ids = list(range(1024)) + [16385] + stage_list = [_stage_with_outputs([_source_output(token_ids)])] + + prompt = { + "prompt": "edit this", + "mm_processor_kwargs": {"target_h": 1024, "target_w": 1024}, + "modalities": ["img2img"], + } + + result = ar2diffusion(stage_list, [0], prompt=[prompt]) + assert len(result) == 1 + + def test_empty_engine_input_source_raises(self): + with pytest.raises(ValueError, match="engine_input_source cannot be empty"): + ar2diffusion([], [], prompt={}) + + def test_invalid_stage_id_raises(self): + with pytest.raises(IndexError, match="Invalid stage_id"): + ar2diffusion([_stage_with_outputs(None)], [5], prompt={}) + + def test_no_outputs_raises(self): + with pytest.raises(RuntimeError, match="has no outputs yet"): + ar2diffusion([SimpleNamespace(engine_outputs=None)], [0], prompt={}) + + def test_default_dimensions(self): + """When no height/width in prompt, defaults to 1024x1024.""" + token_ids = list(range(256)) + list(range(1024)) + [16385] + stage_list = [_stage_with_outputs([_source_output(token_ids)])] + + prompt = {"prompt": "test"} + result = ar2diffusion(stage_list, [0], prompt=[prompt]) + assert result[0]["height"] == 1024 + assert result[0]["width"] == 1024 + + def test_requires_multimodal_data_with_pil_image(self): + """Test that pil_image is included when requires_multimodal_data=True.""" + token_ids = list(range(256)) + list(range(1024)) + [16385] + stage_list = [_stage_with_outputs([_source_output(token_ids)])] + + from PIL import Image + + img = Image.new("RGB", (64, 64)) + prompt = { + "prompt": "test", + "multi_modal_data": {"image": img}, + } + + result = ar2diffusion(stage_list, [0], prompt=[prompt], requires_multimodal_data=True) + assert result[0]["pil_image"] is img + + def test_extra_params_passed_through(self): + """Test that seed, num_inference_steps, guidance_scale, negative_prompt are passed.""" + token_ids = list(range(256)) + list(range(1024)) + [16385] + stage_list = [_stage_with_outputs([_source_output(token_ids)])] + + prompt = { + "prompt": "test", + "seed": 42, + "num_inference_steps": 50, + "guidance_scale": 7.5, + "negative_prompt": "blurry", + } + + result = ar2diffusion(stage_list, [0], prompt=[prompt]) + assert result[0]["seed"] == 42 + assert result[0]["num_inference_steps"] == 50 + assert result[0]["guidance_scale"] == 7.5 + assert result[0]["negative_prompt"] == "blurry" + + def test_batch_requests(self): + """Test processing multiple requests in a batch.""" + tokens1 = list(range(256)) + list(range(1024)) + [16385] + tokens2 = list(range(256)) + list(range(1024)) + [16385] + stage_list = [_stage_with_outputs([_source_output(tokens1), _source_output(tokens2)])] + + prompts = [ + {"prompt": "first", "mm_processor_kwargs": {"target_h": 1024, "target_w": 1024}}, + {"prompt": "second", "mm_processor_kwargs": {"target_h": 512, "target_w": 512}}, + ] + + result = ar2diffusion(stage_list, [0], prompt=prompts) + assert len(result) == 2 + assert result[0]["prompt"] == "first" + assert result[1]["prompt"] == "second" diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index a5d7ad032e9..e86388a388d 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -627,12 +627,21 @@ async def _forward_to_next_stage( if next_client.stage_type == "diffusion": self.stage_clients[stage_id].set_engine_outputs([output]) if next_client.custom_process_input_func is not None: + _t_ar2d = _time.perf_counter() diffusion_prompt = next_client.custom_process_input_func( self.stage_clients, next_client.engine_input_source, req_state.prompt, False, ) + _dt_ar2d = (_time.perf_counter() - _t_ar2d) * 1000 + logger.info( + "[Orchestrator] ar2diffusion req=%s wall_time=%.3fms stage=%d->%d", + req_id, + _dt_ar2d, + stage_id, + next_stage_id, + ) if isinstance(diffusion_prompt, list): diffusion_prompt = diffusion_prompt[0] else: diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 745b719d5b2..b2e6e435a95 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1404,6 +1404,16 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) else: size_str = "model default" + # Keep AR stage target grid in sync with requested output size. + # GLM-Image consumes target_h/target_w via mm_processor_kwargs. + if width is not None and height is not None: + prompt["mm_processor_kwargs"] = { + "target_h": height, + "target_w": width, + } + # Backward-compatible fallback for processors reading top-level fields. + prompt["height"] = height + prompt["width"] = width app_state_args = getattr(raw_request.app.state, "args", None) _check_max_generated_image_size(app_state_args, width, height) @@ -1629,6 +1639,18 @@ async def edit_images( _check_max_generated_image_size(app_state_args, width, height, resolution) size_str = f"{width}x{height}" if width is not None and height is not None else "auto" + + # Keep AR stage target grid in sync with requested output size. + # GLM-Image consumes target_h/target_w via mm_processor_kwargs. + if width is not None and height is not None: + prompt["mm_processor_kwargs"] = { + "target_h": height, + "target_w": width, + } + # Backward-compatible fallback for processors reading top-level fields. + prompt["height"] = height + prompt["width"] = width + _update_if_not_none(gen_params, "width", width) _update_if_not_none(gen_params, "height", height) diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 8cddac6a6c5..aee6c367431 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -304,20 +304,9 @@ async def create_chat_completion( # effectively unconditioned and produce nonsense images. if request.modalities and ("image" in request.modalities): try: - messages_as_dicts: list[dict[str, Any]] = [] - for msg in request.messages: - if hasattr(msg, "model_dump"): - messages_as_dicts.append(msg.model_dump()) - elif isinstance(msg, dict): - messages_as_dicts.append(msg) - else: - messages_as_dicts.append( - { - "role": getattr(msg, "role", "user"), - "content": getattr(msg, "content", ""), - } - ) - extracted_prompt, reference_images = self._extract_diffusion_prompt_and_images(messages_as_dicts) + extracted_prompt, reference_images = self._extract_diffusion_prompt_and_images_from_messages( + request.messages + ) if not extracted_prompt: return self.create_error_response("No text prompt found in messages") @@ -328,41 +317,33 @@ async def create_chat_completion( extra_body = getattr(request, "extra_body", None) if not extra_body: extra_body = request.model_extra or {} - height = extra_body.get("height") - width = extra_body.get("width") + + height, width = self._resolve_height_width_from_extra_body(extra_body) + num_inference_steps = extra_body.get("num_inference_steps") if num_inference_steps is not None: try: num_inference_steps = int(num_inference_steps) except Exception: num_inference_steps = None - if "size" in extra_body: - try: - size_str = extra_body["size"] - if isinstance(size_str, str) and "x" in size_str.lower(): - w, h = size_str.lower().split("x") - width, height = int(w), int(h) - except Exception: - pass + negative_prompt = extra_body.get("negative_prompt") cfg_text_scale = extra_body.get("cfg_text_scale") cfg_img_scale = extra_body.get("cfg_img_scale") engine_prompt_image: dict[str, Any] | None = None - is_img2img = False if reference_images: # Best-effort decode first reference image for i2i. try: img_bytes = base64.b64decode(reference_images[0]) img = Image.open(BytesIO(img_bytes)) engine_prompt_image = {"img2img": img} - is_img2img = True except Exception: engine_prompt_image = None # Override the prompts produced by chat-template preprocessing. tprompt: OmniTextPrompt = {"prompt": extracted_prompt} - if is_img2img: + if engine_prompt_image: tprompt["modalities"] = ["img2img"] else: tprompt["modalities"] = ["image"] @@ -378,6 +359,13 @@ async def create_chat_completion( tprompt["mm_processor_kwargs"] = mm_processor_kwargs if engine_prompt_image is not None: tprompt["multi_modal_data"] = engine_prompt_image + # Provide multi_modal_uuids so that newer vLLM versions + # can validate multi_modal_data / multi_modal_uuids + # consistency. After the multimodal processor consumes + # the image data, the uuids remain as a stable reference. + tprompt["multi_modal_uuids"] = { + k: [f"{request_id}-{k}-{i}"] for i, k in enumerate(engine_prompt_image) + } engine_prompts = [tprompt] # Store height/width for applying to diffusion stage sampling params later @@ -544,20 +532,7 @@ async def _preprocess_chat( # containing image tokens. req_modalities = getattr(request, "modalities", []) if req_modalities and ("image" in req_modalities): - messages_as_dicts: list[dict[str, Any]] = [] - for msg in messages: - if hasattr(msg, "model_dump"): - messages_as_dicts.append(msg.model_dump()) - elif isinstance(msg, dict): - messages_as_dicts.append(msg) - else: - messages_as_dicts.append( - { - "role": getattr(msg, "role", "user"), - "content": getattr(msg, "content", ""), - } - ) - extracted_prompt, _ = self._extract_diffusion_prompt_and_images(messages_as_dicts) + extracted_prompt, _ = self._extract_diffusion_prompt_and_images_from_messages(messages) if extracted_prompt: engine_prompt["prompt"] = extracted_prompt @@ -717,6 +692,9 @@ def _apply_request_overrides( Starts with YAML defaults and only overrides fields that the user explicitly provided (non-None values) in the request. + For GLM-Image AR stage, if max_tokens is not in YAML and user provides + height/width in extra_body, computes max_tokens dynamically. + Args: default_params: Default SamplingParams from stage config YAML. request: The chat completion request containing user-provided values. @@ -726,11 +704,56 @@ def _apply_request_overrides( """ params = default_params.clone() + # Only apply fields explicitly provided by user, not protocol defaults. + # Pydantic v2 uses `model_fields_set`; keep v1 fallback for compatibility. + explicit_fields = getattr(request, "model_fields_set", None) + if explicit_fields is None: + explicit_fields = getattr(request, "__fields_set__", set()) + for field_name in self._OPENAI_SAMPLING_FIELDS: + if field_name not in explicit_fields: + continue + value = getattr(request, field_name, None) if (value is not None and not isinstance(value, list)) or (isinstance(value, list) and len(value) > 0): setattr(params, field_name, value) + # For GLM-Image: compute max_tokens from height/width with mode-aware + # budgeting (t2i vs i2i). + extra_body = getattr(request, "extra_body", {}) or {} + height, width = self._resolve_height_width_from_extra_body(extra_body) + + # Best-effort mode detection from user messages. + # i2i requests include at least one reference image in message content. + _, reference_images = self._extract_diffusion_prompt_and_images_from_messages(request.messages) + ref_image_count = len(reference_images) + is_img2img = ref_image_count > 0 + + if height is not None and width is not None: + try: + from vllm_omni.model_executor.stage_input_processors.glm_image import compute_max_tokens + + max_tokens = getattr(explicit_fields, "max_tokens", None) + if max_tokens is None: + max_tokens = compute_max_tokens(int(height), int(width), is_i2i=is_img2img) + params.max_tokens = max_tokens + # Keep target size in stage-0 sampling params so runner/model can + # build deterministic M-RoPE grids for t2i (no MM features). + extra_args = dict(getattr(params, "extra_args", {}) or {}) + extra_args["target_h"] = int(height) + extra_args["target_w"] = int(width) + params.extra_args = extra_args + except (ImportError, ValueError, TypeError) as e: + logger.warning(f"Failed to compute max_tokens: {e}, using default if available") + else: + logger.info( + "[SamplingParams] Skip dynamic max_tokens (height=%s, width=%s, mode=%s, ref_images=%s)", + height, + width, + "i2i" if is_img2img else "t2i", + ref_image_count, + ) + return params @staticmethod @@ -2643,6 +2666,48 @@ def _extract_diffusion_prompt_and_images( prompt = " ".join(prompt_parts).strip() return prompt, images + def _extract_diffusion_prompt_and_images_from_messages( + self, + messages: list[Any], + ) -> tuple[str, list[str]]: + """Normalize mixed message types and extract prompt + reference images once.""" + return self._extract_diffusion_prompt_and_images(self._messages_to_dicts(messages)) + + @staticmethod + def _messages_to_dicts(messages: list[Any]) -> list[dict[str, Any]]: + """Normalize request messages to plain dicts.""" + out: list[dict[str, Any]] = [] + for msg in messages: + if hasattr(msg, "model_dump"): + out.append(msg.model_dump()) + elif isinstance(msg, dict): + out.append(msg) + else: + out.append( + { + "role": getattr(msg, "role", "user"), + "content": getattr(msg, "content", ""), + } + ) + return out + + @staticmethod + def _resolve_height_width_from_extra_body(extra_body: dict[str, Any]) -> tuple[Any, Any]: + """Extract generation height/width with optional size string fallback.""" + height = extra_body.get("height") + width = extra_body.get("width") + + if "size" in extra_body and (height is None or width is None): + try: + size_str = extra_body["size"] + if isinstance(size_str, str) and "x" in size_str.lower(): + w, h = size_str.lower().split("x") + width, height = int(w), int(h) + except Exception: + pass + + return height, width + def _create_error_response( self, message: str, diff --git a/vllm_omni/inputs/preprocess.py b/vllm_omni/inputs/preprocess.py index c6dffd05426..cca6ce56870 100644 --- a/vllm_omni/inputs/preprocess.py +++ b/vllm_omni/inputs/preprocess.py @@ -29,6 +29,8 @@ def _process_text( self, parsed_content: OmniTextPrompt, tokenization_kwargs: dict[str, Any] | None = None, + *, + mm_uuids: Any | None = None, ) -> OmniTokenInputs | MultiModalInput: """Process text prompts with support for mm_processor_kwargs. @@ -38,6 +40,10 @@ def _process_text( """ prompt_text = parsed_content["prompt"] mm_processor_kwargs = parsed_content.get("mm_processor_kwargs") or {} + # When the deprecated raw-prompt path is used, process_inputs does + # not pass mm_uuids to preprocess(). Fall back to reading it from + # the prompt dict so the Renderer's _validate_mm_uuids can see it. + effective_mm_uuids = mm_uuids or parsed_content.get("multi_modal_uuids") inputs: OmniTokenInputs | MultiModalInput if multi_modal_data := parsed_content.get("multi_modal_data"): @@ -46,6 +52,7 @@ def _process_text( multi_modal_data, mm_processor_kwargs, tokenization_kwargs=tokenization_kwargs, + mm_uuids=effective_mm_uuids, ) prompt_embeds = parsed_content.get("prompt_embeds") if prompt_embeds is not None: @@ -59,6 +66,7 @@ def _process_text( {}, mm_processor_kwargs, tokenization_kwargs=tokenization_kwargs, + mm_uuids=effective_mm_uuids, ) else: prompt_token_ids = self._tokenize_prompt( @@ -142,6 +150,8 @@ def _prompt_to_llm_inputs( self, prompt: SingletonDictPrompt, tokenization_kwargs: dict[str, Any] | None = None, + *, + mm_uuids: Any | None = None, ) -> SingletonInput: """ Extract the singleton inputs from a prompt. @@ -166,6 +176,7 @@ def _prompt_to_llm_inputs( return self._process_text( prompt, # type: ignore[arg-type] tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, ) assert_never(prompt) # type: ignore[arg-type] diff --git a/vllm_omni/model_executor/models/glm_image/glm_image_ar.py b/vllm_omni/model_executor/models/glm_image/glm_image_ar.py index 31eed9b2cb9..bf21c01a645 100644 --- a/vllm_omni/model_executor/models/glm_image/glm_image_ar.py +++ b/vllm_omni/model_executor/models/glm_image/glm_image_ar.py @@ -21,6 +21,7 @@ # limitations under the License. """Inference-only GLM-Image model compatible with HuggingFace weights.""" +import math import os from collections.abc import Iterable, Mapping, Sequence from typing import Annotated, Literal @@ -127,6 +128,14 @@ def _get_subparsers(self): parsers["img2img"] = self._parse_image_data return parsers + def parse_mm_data(self, mm_data, **kwargs): + # Normalize "img2img" to "image" so the rest of the pipeline + # (mm_hashes, _merge_mm_kwargs) uses a single modality key. + normalized = {} + for k, v in mm_data.items(): + normalized["image" if k == "img2img" else k] = v + return super().parse_mm_data(normalized, **kwargs) + class GlmImageProcessingInfo(BaseProcessingInfo): """ @@ -346,6 +355,10 @@ def _call_hf_processor( target_h = mm_kwargs.get("target_h", 1024) if mm_kwargs else 1024 target_w = mm_kwargs.get("target_w", 1024) if mm_kwargs else 1024 + logger.debug( + f"_call_hf_processor: target dimensions for generation: {target_h}x{target_w}, mm_kwargs={mm_kwargs}" + ) + if not mm_data or not mm_data.get("images"): # Text-to-image mode if processor is not None: @@ -566,6 +579,58 @@ def _apply_hf_processor_mm_only( tensor_type="pt", ) + def _apply_hf_processor_text_only( + self, prompt_text: str, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object] + ) -> list[int]: + prompt_ids, _, _ = super()._apply_hf_processor_text_mm( + prompt_text=prompt_text, + mm_items=MultiModalDataItems({}), + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + ) + return prompt_ids + + def _build_generation_grids(self, hf_processor_mm_kwargs: Mapping[str, object]) -> torch.Tensor: + """Build generation grids for M-RoPE decode positions. + + For GLM-Image generation, decode order is: + 1) small preview grid + 2) large target grid + 3) EOS + + We store grids as [large, small] to match HF processor behavior, and + decode logic consumes them in reverse order. + """ + + target_h = ( + hf_processor_mm_kwargs.get("target_h") if isinstance(hf_processor_mm_kwargs.get("target_h"), int) else None + ) + target_w = ( + hf_processor_mm_kwargs.get("target_w") if isinstance(hf_processor_mm_kwargs.get("target_w"), int) else None + ) + if target_h is None or target_w is None: + target_h = ( + hf_processor_mm_kwargs.get("height") if isinstance(hf_processor_mm_kwargs.get("height"), int) else 1024 + ) + target_w = ( + hf_processor_mm_kwargs.get("width") if isinstance(hf_processor_mm_kwargs.get("width"), int) else 1024 + ) + + factor = 32 + target_h = (target_h // factor) * factor + target_w = (target_w // factor) * factor + token_h = target_h // factor + token_w = target_w // factor + + ratio = token_h / token_w if token_w > 0 else 1.0 + small_token_h = max(1, int(math.sqrt(ratio) * (factor // 2))) + small_token_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2))) + + return torch.tensor( + [[1, token_h, token_w], [1, small_token_h, small_token_w]], + dtype=torch.long, + ) + def _apply_hf_processor_main( self, prompt: str | list[int], @@ -594,126 +659,145 @@ def _apply_hf_processor_main( logger.debug(f"_apply_hf_processor_main: mm_counts={mm_counts}, num_images={num_images}") - if num_images == 0 or enable_hf_prompt_update: + if num_images == 0 and isinstance(prompt, str): # t2i mode or normal flow - use parent implementation - return super()._apply_hf_processor_main( - prompt=prompt, + prompt_ids = self._apply_hf_processor_text_only( + prompt_text=prompt, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + ) + mm_processed_data = self._apply_hf_processor_mm_only( mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - enable_hf_prompt_update=enable_hf_prompt_update, ) - # i2i mode with enable_hf_prompt_update=False (cache miss scenario) - # We need to build prompt_ids with image placeholders - logger.debug(f"_apply_hf_processor_main: i2i mode with enable_hf_prompt_update=False, num_images={num_images}") - - # Get mm data from our overridden _apply_hf_processor_mm_only - mm_processed_data = self._apply_hf_processor_mm_only( - mm_items=mm_items, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, - tokenization_kwargs=tokenization_kwargs, - ) - - # In this path we do NOT call HF apply_chat_template, so we must still - # provide full grids (source + target) for M-RoPE to compute decode positions. - # Keep `image_grid_thw` source-only for MM batching/validation. - try: - source_grid_thw = mm_processed_data.get("image_grid_thw") - if source_grid_thw is not None and isinstance(source_grid_thw, torch.Tensor): - # Compute target grid following HF GlmImageProcessor: factor=32. - # Prefer explicit target_h/target_w if present, otherwise fall back. - target_h = ( - hf_processor_mm_kwargs.get("target_h") - if isinstance(hf_processor_mm_kwargs.get("target_h"), int) - else None - ) - target_w = ( - hf_processor_mm_kwargs.get("target_w") - if isinstance(hf_processor_mm_kwargs.get("target_w"), int) - else None + # t2i has no source images, so mm features cannot provide image_grid_thw. + # Provide explicit generation grids for M-RoPE to avoid fallback token parsing + # (which can degrade high-resolution spatial positions, e.g. 1920x1920). + try: + mrope_grid_thw = self._build_generation_grids(hf_processor_mm_kwargs) + mm_processed_data["mrope_image_grid_thw"] = mrope_grid_thw + logger.info( + "_apply_hf_processor_main t2i: mrope_image_grid_thw=%s", + mrope_grid_thw.tolist(), ) - if target_h is None or target_w is None: - # Some callers pass generation size as height/width. - target_h = ( - hf_processor_mm_kwargs.get("height") - if isinstance(hf_processor_mm_kwargs.get("height"), int) - else 1024 - ) - target_w = ( - hf_processor_mm_kwargs.get("width") - if isinstance(hf_processor_mm_kwargs.get("width"), int) - else 1024 - ) + except Exception as e: + logger.warning("_apply_hf_processor_main t2i: failed to set mrope_image_grid_thw: %s", e) - factor = 32 - target_h = (target_h // factor) * factor - target_w = (target_w // factor) * factor - token_h = target_h // factor - token_w = target_w // factor - target_grid = torch.tensor([[1, token_h, token_w]], dtype=source_grid_thw.dtype) + return prompt_ids, mm_processed_data, False - mm_processed_data["mrope_image_grid_thw"] = torch.cat([source_grid_thw, target_grid], dim=0) - except Exception: - # Best-effort only; M-RoPE has additional fallbacks. - pass + # i2i mode: use unified HF processor path only. + # This avoids drift between duplicated manual/HF i2i implementations. + logger.debug( + "_apply_hf_processor_main: i2i mode (enable_hf_prompt_update=%s), num_images=%s", + enable_hf_prompt_update, + num_images, + ) - # Build prompt_ids with image placeholders - # _apply_prompt_updates will replace each [image_token_id] with expanded tokens - tokenizer = self.info.get_tokenizer() - image_token_id = tokenizer.convert_tokens_to_ids("<|image|>") + if not isinstance(prompt, str): + # Online OpenAI chat preprocessing can arrive here with tokenized + # prompts (list[int]) before serving_chat replaces engine prompt + # with the clean text prompt. Do not fail the whole request. + logger.warning( + "_apply_hf_processor_main i2i: got tokenized prompt type=%s; " + "using compatibility path for preprocessing", + type(prompt).__name__, + ) + + prompt_ids = list(prompt) + mm_processed_data = self._apply_hf_processor_mm_only( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + ) - if isinstance(prompt, str): - # Match HF GlmImageProcessor behavior: append target grid tokens + BOS. - # This helps M-RoPE/grid parsing and keeps i2i vs t2i behavior aligned. + # Preserve full grids for M-RoPE decode (source + target), while + # keeping image_grid_thw source-only for MM batching. try: - grid_bos = getattr(tokenizer, "grid_bos_token", "") - grid_eos = getattr(tokenizer, "grid_eos_token", "") - bos = getattr(tokenizer, "bos_token", "") - - # Use the same target sizes we used for mrope grids when available. - target_h = ( - hf_processor_mm_kwargs.get("target_h") - if isinstance(hf_processor_mm_kwargs.get("target_h"), int) - else None - ) - target_w = ( - hf_processor_mm_kwargs.get("target_w") - if isinstance(hf_processor_mm_kwargs.get("target_w"), int) - else None - ) - if target_h is None or target_w is None: + source_grid_thw = mm_processed_data.get("image_grid_thw") + if source_grid_thw is not None and isinstance(source_grid_thw, torch.Tensor): target_h = ( - hf_processor_mm_kwargs.get("height") - if isinstance(hf_processor_mm_kwargs.get("height"), int) - else 1024 + hf_processor_mm_kwargs.get("target_h") + if isinstance(hf_processor_mm_kwargs.get("target_h"), int) + else None ) target_w = ( - hf_processor_mm_kwargs.get("width") - if isinstance(hf_processor_mm_kwargs.get("width"), int) - else 1024 + hf_processor_mm_kwargs.get("target_w") + if isinstance(hf_processor_mm_kwargs.get("target_w"), int) + else None ) + if target_h is None or target_w is None: + target_h = ( + hf_processor_mm_kwargs.get("height") + if isinstance(hf_processor_mm_kwargs.get("height"), int) + else 1024 + ) + target_w = ( + hf_processor_mm_kwargs.get("width") + if isinstance(hf_processor_mm_kwargs.get("width"), int) + else 1024 + ) - factor = 32 - target_h = (target_h // factor) * factor - target_w = (target_w // factor) * factor - token_h = target_h // factor - token_w = target_w // factor - - expanded_prompt = f"{prompt}{grid_bos}{token_h} {token_w}{grid_eos}{bos}" - text_ids = tokenizer.encode(expanded_prompt, add_special_tokens=False) + factor = 32 + token_h = max(1, target_h // factor) + token_w = max(1, target_w // factor) + target_grid = torch.tensor([[1, token_h, token_w]], dtype=source_grid_thw.dtype) + mm_processed_data["mrope_image_grid_thw"] = torch.cat([source_grid_thw, target_grid], dim=0) except Exception: - text_ids = tokenizer.encode(prompt, add_special_tokens=False) + pass + + # Prompt updates will expand image placeholders in this compatibility path. + return prompt_ids, mm_processed_data, False + + images = mm_items.get_items("image", ImageProcessorItems) + image_list = [images.get(i) for i in range(images.get_count())] + if not image_list: + raise ValueError("GLM-Image i2i requires at least one source image in mm_items") + + hf_inputs = self._call_hf_processor( + prompt=prompt, + mm_data={"images": image_list}, + mm_kwargs=hf_processor_mm_kwargs, + tok_kwargs=tokenization_kwargs, + ) + + input_ids = hf_inputs.get("input_ids") + if input_ids is None: + raise ValueError("HF i2i processor returned no input_ids") + + if isinstance(input_ids, torch.Tensor): + prompt_ids = input_ids[0].tolist() if input_ids.dim() > 1 else input_ids.tolist() else: - text_ids = list(prompt) + prompt_ids = ( + input_ids[0] + if isinstance(input_ids, list) and input_ids and isinstance(input_ids[0], list) + else list(input_ids) + ) - # Prepend image placeholders - one per image - prompt_ids = [image_token_id] * num_images + text_ids + mm_processed_data = BatchFeature(dict(), tensor_type="pt") + for key in ("pixel_values", "image_grid_thw", "mrope_image_grid_thw"): + value = hf_inputs.get(key) + if value is not None: + mm_processed_data[key] = value - logger.debug(f"_apply_hf_processor_main: built prompt_ids with {num_images} image placeholders") + image_grid_thw = mm_processed_data.get("image_grid_thw") + mrope_grid_thw = mm_processed_data.get("mrope_image_grid_thw") + hf_config = self.info.get_hf_config() + image_token_id = getattr(hf_config, "image_token_id", 167855) + image_token_count = prompt_ids.count(image_token_id) + logger.info( + "_apply_hf_processor_main i2i(HF): num_images=%s, prompt_len=%s, image_token_count=%s, " + "source_grid_shape=%s, mrope_grid_shape=%s", + num_images, + len(prompt_ids), + image_token_count, + tuple(image_grid_thw.shape) if image_grid_thw is not None else None, + tuple(mrope_grid_thw.shape) if mrope_grid_thw is not None else None, + ) - # Return is_update_applied=False so _apply_prompt_updates will expand the placeholders - return prompt_ids, mm_processed_data, False + # HF processor already expanded image placeholders in input_ids. + return prompt_ids, mm_processed_data, True def _get_mm_fields_config( self, @@ -2667,9 +2751,23 @@ def get_mrope_input_positions( # Input format: "textH Wh w" where =image_start_token_id=16384 # For 1024x1024: H=32, W=32 (large), h=16, w=16 (small preview) if not image_grid_thw: + # Preferred path for t2i: use explicit target size propagated from + # serving/request sampling params. This avoids fragile grid parsing + # from token IDs and matches HF processor grid construction. + target_h = kwargs.get("target_h") + target_w = kwargs.get("target_w") + if isinstance(target_h, int) and isinstance(target_w, int) and target_h > 0 and target_w > 0: + factor = 32 + token_h = target_h // factor + token_w = target_w // factor + ratio = token_h / token_w if token_w > 0 else 1.0 + small_h = max(1, int(math.sqrt(ratio) * (factor // 2))) + small_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2))) + image_grid_thw = [[1, token_h, token_w], [1, small_h, small_w]] + # Try to parse from kwargs (passed from processor) hf_config_arg = kwargs.get("hf_config") - if hf_config_arg is not None and hasattr(hf_config_arg, "image_grid_thw"): + if (not image_grid_thw) and hf_config_arg is not None and hasattr(hf_config_arg, "image_grid_thw"): image_grid_thw = hf_config_arg.image_grid_thw # If still empty, try to infer from input tokens @@ -2723,19 +2821,29 @@ def get_mrope_input_positions( prompt_ends_with_start = len(input_tokens) > 0 and input_tokens[-1] == image_start_token_id if prompt_ends_with_start and len(image_grid_thw) == num_source_images and num_source_images > 0: # i2i mode: source grids exist but no target grids - # Parse target grids from prompt tokens or use defaults - parsed_grids = self._parse_grid_from_tokens(input_tokens, hf_config) - if parsed_grids: - # parsed_grids contains all grids mentioned in prompt - # For i2i, add only the generation target grids - if len(parsed_grids) > num_source_images: - image_grid_thw = list(image_grid_thw) + parsed_grids[num_source_images:] + # Prefer explicit target size propagated from request sampling params. + # This avoids fragile grid parsing from token IDs for non-1024 i2i. + target_h = kwargs.get("target_h") + target_w = kwargs.get("target_w") + if isinstance(target_h, int) and isinstance(target_w, int) and target_h > 0 and target_w > 0: + factor = 32 + token_h = target_h // factor + token_w = target_w // factor + image_grid_thw = list(image_grid_thw) + [[1, token_h, token_w]] + else: + # Parse target grids from prompt tokens or use defaults + parsed_grids = self._parse_grid_from_tokens(input_tokens, hf_config) + if parsed_grids: + # parsed_grids contains all grids mentioned in prompt + # For i2i, add only the generation target grids + if len(parsed_grids) > num_source_images: + image_grid_thw = list(image_grid_thw) + parsed_grids[num_source_images:] + else: + # Fallback: add default 1024x1024 generation grid (1 target for i2i) + image_grid_thw = list(image_grid_thw) + [[1, 32, 32]] else: - # Fallback: add default 1024x1024 generation grids (1 target for i2i) + # Fallback to default 1024x1024 grid for generation image_grid_thw = list(image_grid_thw) + [[1, 32, 32]] - else: - # Fallback to default 1024x1024 grids for generation - image_grid_thw = list(image_grid_thw) + [[1, 32, 32]] llm_pos_ids_list: list[torch.Tensor] = [] diff --git a/vllm_omni/model_executor/stage_configs/glm_image.yaml b/vllm_omni/model_executor/stage_configs/glm_image.yaml index 05ac84a7a09..f3ed6c7213d 100644 --- a/vllm_omni/model_executor/stage_configs/glm_image.yaml +++ b/vllm_omni/model_executor/stage_configs/glm_image.yaml @@ -33,7 +33,6 @@ stage_args: temperature: 0.9 # From model's generation_config.json top_p: 0.75 # From model's generation_config.json top_k: 16512 # vision_vocab_size from generation_config.json - max_tokens: 1281 # For 1024x1024: small(16x16=256) + large(32x32=1024) + EOS(1) stop_token_ids: [16385] # eos_token_id from generation_config.json seed: 42 detokenize: false diff --git a/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml b/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml index 7bd66c403fc..2a85a6dadbc 100644 --- a/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml +++ b/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml @@ -35,7 +35,6 @@ stage_args: temperature: 0.9 # From model's generation_config.json top_p: 0.75 # From model's generation_config.json top_k: 16512 # vision_vocab_size from generation_config.json - max_tokens: 1281 # For 1024x1024: small(16x16=256) + large(32x32=1024) + EOS(1) stop_token_ids: [16385] # eos_token_id from generation_config.json seed: 42 detokenize: false diff --git a/vllm_omni/model_executor/stage_input_processors/glm_image.py b/vllm_omni/model_executor/stage_input_processors/glm_image.py index 3063620bf8f..53e99610f5c 100644 --- a/vllm_omni/model_executor/stage_input_processors/glm_image.py +++ b/vllm_omni/model_executor/stage_input_processors/glm_image.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Stage input processor for GLM-Image: AR → Diffusion transition.""" +import math +import time from typing import Any import torch @@ -13,6 +15,86 @@ logger = init_logger(__name__) +def _has_source_image(mm_data: Any) -> bool: + """Return whether prompt multi_modal_data contains a source image. + + Normalizes legacy/new keys used across omni pipelines: + - `image`: single PIL image or list + - `img2img`: legacy single-image key + - `images`: list or single image + """ + if not isinstance(mm_data, dict): + return False + if mm_data.get("image") is not None: + return True + if mm_data.get("img2img") is not None: + return True + images = mm_data.get("images") + return bool(images) + + +def _first_source_image(mm_data: Any) -> Any: + """Get first source image from normalized multimodal keys.""" + if not isinstance(mm_data, dict): + return None + + image = mm_data.get("image") + if image is not None: + if isinstance(image, list): + return image[0] if image else None + return image + + image = mm_data.get("img2img") + if image is not None: + if isinstance(image, list): + return image[0] if image else None + return image + + images = mm_data.get("images") + if isinstance(images, list): + return images[0] if images else None + return images + + +def compute_max_tokens(height: int, width: int, factor: int = 32, is_i2i: bool = False) -> int: + """ + Compute max_new_tokens for GLM-Image AR generation. + + GLM-Image generation differs by mode: + + - text-to-image (t2i): small preview + large target + EOS + - image-to-image (i2i): large target + EOS + + Args: + height: Target image height in pixels + width: Target image width in pixels + factor: Downsampling factor (32 for GLM-Image AR output) + is_i2i: Whether the request is image-to-image mode + + Returns: + Total number of tokens to generate for the specified mode + """ + # Large image tokens (target resolution) + token_h = height // factor + token_w = width // factor + large_tokens = token_h * token_w + + # Small preview tokens (half resolution in each dimension) + import math + + ratio = token_h / token_w if token_w > 0 else 1.0 + small_token_h = max(1, int(math.sqrt(ratio) * (factor // 2))) + small_token_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2))) + small_tokens = small_token_h * small_token_w + + # Mode-dependent totals: + # - t2i: small + large + EOS + # - i2i: large + EOS + if is_i2i: + return large_tokens + 1 + return small_tokens + large_tokens + 1 + + def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor: """Upsample token IDs by 2x using nearest neighbor interpolation. @@ -56,39 +138,49 @@ def _parse_generated_tokens( large_image_tokens = token_h * token_w # Calculate small preview image dimensions (used in text-to-image) - small_token_h = token_h // 2 - small_token_w = token_w // 2 + ratio = token_h / token_w if token_w > 0 else 1.0 + small_token_h = max(1, int(math.sqrt(ratio) * (factor // 2))) + small_token_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2))) small_image_tokens = small_token_h * small_token_w token_tensor = torch.tensor(token_ids, dtype=torch.long) # Remove EOS token (16385) from the end if present eos_token_id = 16385 - if len(token_ids) > 0 and token_ids[-1] == eos_token_id: + has_terminal_eos = len(token_ids) > 0 and token_ids[-1] == eos_token_id + if has_terminal_eos: token_tensor = token_tensor[:-1] actual_tokens = len(token_tensor) - logger.debug( - f"[_parse_generated_tokens] height={height}, width={width}, " - f"token_h={token_h}, token_w={token_w}, " - f"large_image_tokens={large_image_tokens}, small_image_tokens={small_image_tokens}, " - f"actual_tokens={actual_tokens}" - ) - if is_i2i: - # Image-to-image mode: check if AR generated small+large tokens (like t2i) or just large tokens - # Some AR models output small+large even in i2i mode if actual_tokens >= small_image_tokens + large_image_tokens: - # AR generated full t2i-style output, extract large tokens after small large_start = small_image_tokens large_end = large_start + large_image_tokens prior_token_ids_d32 = token_tensor[large_start:large_end] actual_h, actual_w = token_h, token_w - else: - # AR generated only large tokens (pure i2i output) + logger.warning( + "[_parse_generated_tokens] i2i detected t2i-style token layout; " + "using small-offset extraction: large_start=%s large_end=%s", + large_start, + large_end, + ) + elif actual_tokens >= large_image_tokens: prior_token_ids_d32 = token_tensor[:large_image_tokens] actual_h, actual_w = token_h, token_w + logger.info( + "[_parse_generated_tokens] i2i using offset-0 extraction: large_tokens=%s", + large_image_tokens, + ) + else: + logger.warning( + "[_parse_generated_tokens] i2i token parse failed: actual_tokens=%s < expected_large_tokens=%s", + actual_tokens, + large_image_tokens, + ) + raise ValueError( + f"i2i token parse failed: actual_tokens={actual_tokens} < expected_large_tokens={large_image_tokens}" + ) elif actual_tokens >= small_image_tokens + large_image_tokens: # Text-to-image: extract large image tokens after small image tokens large_start = small_image_tokens @@ -96,43 +188,22 @@ def _parse_generated_tokens( prior_token_ids_d32 = token_tensor[large_start:large_end] actual_h, actual_w = token_h, token_w elif actual_tokens >= large_image_tokens: - # Image-to-image: large image tokens are at the beginning - prior_token_ids_d32 = token_tensor[:large_image_tokens] - actual_h, actual_w = token_h, token_w + logger.warning( + "[_parse_generated_tokens] t2i token parse failed: got only large tokens without small preview " + "(actual_tokens=%s, expected_small_plus_large=%s)", + actual_tokens, + small_image_tokens + large_image_tokens, + ) + raise ValueError("t2i token parse failed: missing small-preview tokens; refusing low-quality fallback") else: - # Insufficient tokens - try to infer the actual grid size - import math - - for scale in [1, 2, 4]: - test_h = token_h // scale - test_w = token_w // scale - test_small_h = test_h // 2 - test_small_w = test_w // 2 - test_large = test_h * test_w - test_small = test_small_h * test_small_w - - if actual_tokens >= test_small + test_large: - prior_token_ids_d32 = token_tensor[test_small : test_small + test_large] - actual_h, actual_w = test_h, test_w - height = test_h * factor - width = test_w * factor - logger.warning(f"Adjusted grid to {test_h}x{test_w}, output will be {height}x{width}") - break - elif actual_tokens >= test_large: - prior_token_ids_d32 = token_tensor[:test_large] - actual_h, actual_w = test_h, test_w - height = test_h * factor - width = test_w * factor - logger.warning(f"Adjusted grid to {test_h}x{test_w}, output will be {height}x{width}") - break - else: - sqrt_tokens = int(math.sqrt(actual_tokens)) - actual_h = actual_w = sqrt_tokens - usable_tokens = sqrt_tokens * sqrt_tokens - prior_token_ids_d32 = token_tensor[:usable_tokens] - height = sqrt_tokens * factor - width = sqrt_tokens * factor - logger.error(f"Grid pattern mismatch. Using {sqrt_tokens}x{sqrt_tokens}, output: {height}x{width}") + logger.warning( + "[_parse_generated_tokens] token parse failed: insufficient tokens " + "(actual_tokens=%s, expected=%s, mode=%s)", + actual_tokens, + large_image_tokens if is_i2i else (small_image_tokens + large_image_tokens), + "i2i" if is_i2i else "t2i", + ) + raise ValueError(f"token parse failed: actual_tokens={actual_tokens}, mode={'i2i' if is_i2i else 't2i'}") # Upsample from 32x to 16x prior_token_ids = _upsample_token_ids(prior_token_ids_d32, actual_h, actual_w) @@ -147,6 +218,8 @@ def ar2diffusion( requires_multimodal_data: bool = False, ) -> list[dict[str, Any]]: """Process AR stage outputs to create Diffusion stage inputs.""" + _t_total = time.perf_counter() + if not engine_input_source: raise ValueError("engine_input_source cannot be empty") @@ -165,6 +238,7 @@ def ar2diffusion( prompt = [prompt] if prompt is not None else [{}] for i, ar_output in enumerate(ar_outputs): + _t_req = time.perf_counter() output = ar_output.outputs[0] generated_token_ids = output.token_ids @@ -179,23 +253,76 @@ def ar2diffusion( else: original_prompt = {} - height = original_prompt.get("height", 1024) - width = original_prompt.get("width", 1024) + mm_processor_kwargs = original_prompt.get("mm_processor_kwargs") + + def _coerce_dim(v: Any, default: int) -> int: + try: + iv = int(v) + return iv if iv > 0 else default + except (TypeError, ValueError): + return default + + # Prefer GLM-Image target size from mm_processor_kwargs (set by serving layer), + # then fall back to top-level fields for backward compatibility. + height = _coerce_dim( + mm_processor_kwargs.get("target_h") if isinstance(mm_processor_kwargs, dict) else None, + _coerce_dim(original_prompt.get("height"), 1024), + ) + width = _coerce_dim( + mm_processor_kwargs.get("target_w") if isinstance(mm_processor_kwargs, dict) else None, + _coerce_dim(original_prompt.get("width"), 1024), + ) text_prompt = original_prompt.get("prompt", "") - # Detect i2i mode first by checking if multimodal_output contains prior_token_image_ids + # Detect i2i mode. + # Prefer normalized prompt multi_modal_data source-image presence, with + # multimodal output as secondary signal. + _t_mode = time.perf_counter() is_i2i = False + + prompt_modalities = original_prompt.get("modalities") + if isinstance(prompt_modalities, list) and "img2img" in prompt_modalities: + is_i2i = True + + prompt_mm_data = original_prompt.get("multi_modal_data") + if _has_source_image(prompt_mm_data): + is_i2i = True + if hasattr(ar_output, "multimodal_output") and ar_output.multimodal_output: mm_output = ar_output.multimodal_output - if isinstance(mm_output, dict) and mm_output.get("prior_token_image_ids") is not None: - is_i2i = True + if isinstance(mm_output, dict): + if mm_output.get("prior_token_image_ids") is not None: + is_i2i = True + _dt_mode = (time.perf_counter() - _t_mode) * 1000 # Parse and upsample prior tokens - prior_token_ids, pixel_h, pixel_w = _parse_generated_tokens(generated_token_ids, height, width, is_i2i=is_i2i) + _t_parse = time.perf_counter() + try: + prior_token_ids, pixel_h, pixel_w = _parse_generated_tokens( + generated_token_ids, + height, + width, + is_i2i=is_i2i, + ) + except ValueError as e: + logger.warning( + "[ar2diffusion] Request %s: skip due to token parse failure: %s " + "(target=%sx%s, mode=%s, raw_tokens=%s, tail=%s)", + i, + e, + height, + width, + "i2i" if is_i2i else "t2i", + len(generated_token_ids), + generated_token_ids[-8:] if len(generated_token_ids) >= 8 else generated_token_ids, + ) + continue + _dt_parse = (time.perf_counter() - _t_parse) * 1000 # Get prior_token_image_ids from AR model output (for i2i mode) # This contains VQ-VAE tokens from input image, used for KV cache conditioning # NOTE: multimodal_output is attached to ar_output (RequestOutput), NOT output (CompletionOutput) + _t_prior_img = time.perf_counter() prior_token_image_ids = None # Check ar_output (RequestOutput) for multimodal_output - this is the correct location @@ -234,6 +361,7 @@ def ar2diffusion( prior_token_image_ids = [raw_prior_image_ids] elif isinstance(raw_prior_image_ids, list): prior_token_image_ids = raw_prior_image_ids + _dt_prior_img = (time.perf_counter() - _t_prior_img) * 1000 diffusion_input = { "prompt": text_prompt, @@ -248,18 +376,38 @@ def ar2diffusion( if requires_multimodal_data: mm_data = original_prompt.get("multi_modal_data") if mm_data: - pil_image = mm_data.get("image") - if pil_image is None: - # Try "images" (plural) as fallback - images = mm_data.get("images") - if images: - pil_image = images[0] if isinstance(images, list) else images + pil_image = _first_source_image(mm_data) diffusion_input["pil_image"] = pil_image for key in ["seed", "num_inference_steps", "guidance_scale", "negative_prompt"]: if key in original_prompt: diffusion_input[key] = original_prompt[key] + _dt_req = (time.perf_counter() - _t_req) * 1000 + logger.info( + "[ar2diffusion] req=%d mode=%s target=%dx%d " + "raw_tokens=%d prior_tokens=%d prior_image_ids=%s " + "timing: mode_detect=%.3fms parse+upsample=%.3fms " + "prior_image_ids_extract=%.3fms req_total=%.3fms", + i, + "i2i" if is_i2i else "t2i", + pixel_h, + pixel_w, + len(generated_token_ids), + len(prior_token_ids), + "yes" if prior_token_image_ids is not None else "no", + _dt_mode, + _dt_parse, + _dt_prior_img, + _dt_req, + ) diffusion_inputs.append(diffusion_input) + _dt_total = (time.perf_counter() - _t_total) * 1000 + logger.info( + "[ar2diffusion] batch done: %d reqs, total=%.3fms", + len(diffusion_inputs), + _dt_total, + ) + return diffusion_inputs diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index d1c15eac640..77f487725ec 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -152,8 +152,10 @@ def _init_mrope_positions(self, req_state: CachedRequestState): if supports_mrope(self.get_model()): # Model implements SupportsMRoPE interface # Pass all extracted metadata; models use what they need via **kwargs - req_state.mrope_positions, req_state.mrope_position_delta = self.model.get_mrope_input_positions( - req_state.prompt_token_ids, + sp_extra_args = getattr(req_state.sampling_params, "extra_args", {}) if req_state.sampling_params else {} + target_h = sp_extra_args.get("target_h") if isinstance(sp_extra_args, dict) else None + target_w = sp_extra_args.get("target_w") if isinstance(sp_extra_args, dict) else None + kwargs = dict( mm_features=req_state.mm_features, hf_config=self.model_config.hf_config, image_grid_thw=image_grid_thw, @@ -162,6 +164,14 @@ def _init_mrope_positions(self, req_state: CachedRequestState): audio_feature_lengths=audio_feature_lengths, use_audio_in_video=use_audio_in_video, ) + if target_h is not None: + kwargs["target_h"] = target_h + if target_w is not None: + kwargs["target_w"] = target_w + req_state.mrope_positions, req_state.mrope_position_delta = self.model.get_mrope_input_positions( + req_state.prompt_token_ids, + **kwargs, + ) else: req_state.mrope_positions, req_state.mrope_position_delta = MRotaryEmbedding.get_input_positions_tensor( req_state.prompt_token_ids,