diff --git a/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py b/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py
index 0a9d45a8ccb..b91f22e9bfb 100644
--- a/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py
+++ b/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py
@@ -697,3 +697,59 @@ def test_size_string_parsed_for_glm_image(
         result = glm_serving_chat._apply_request_overrides(default_comprehension_params, req)
         # 512x512 t2i = 256 + 256 + 1 = 513
         assert result.max_tokens == 513
+
+    def test_falls_back_to_diffusion_stage_defaults_when_no_extra_body(
+        self, glm_serving_chat, mocker: MockerFixture, default_comprehension_params
+    ):
+        """No extra_body → serving_chat pulls h/w from any stage's default
+        sampling params. Simulates the recipe's bare-curl: GLM-Image stage-1
+        yaml declares height=1024, width=1024 which feeds the AR max_tokens
+        compute so the AR doesn't fall through to vLLM's max_model_len.
+        """
+        from types import SimpleNamespace
+
+        diffusion_defaults = SimpleNamespace(height=1024, width=1024)
+        glm_serving_chat.engine_client.default_sampling_params_list = [
+            default_comprehension_params,
+            diffusion_defaults,
+        ]
+
+        req = mocker.MagicMock()
+        for f in (
+            "temperature",
+            "top_p",
+            "top_k",
+            "max_tokens",
+            "min_tokens",
+            "seed",
+            "ignore_eos",
+            "stop",
+            "stop_token_ids",
+            "frequency_penalty",
+            "presence_penalty",
+        ):
+            setattr(req, f, None)
+        req.extra_body = {}
+        req.model_fields_set = set()
+
+        result = glm_serving_chat._apply_request_overrides(default_comprehension_params, req)
+        # t2i 1024x1024 = 256 + 1024 + 1 = 1281
+        assert result.max_tokens == 1281
+        assert result.extra_args["target_h"] == 1024
+        assert result.extra_args["target_w"] == 1024
+
+    def test_explicit_null_max_tokens_still_computes(self, glm_serving_chat, glm_request, default_comprehension_params):
+        """Sending ``max_tokens=None`` explicitly (Pydantic adds ``max_tokens``
+        to ``model_fields_set`` even when the value is None) must not suppress
+        the compute. The field-copy loop drops None values, so the compute
+        must still populate ``params.max_tokens`` from the target size —
+        otherwise ``max_tokens`` stays unset and falls through to vLLM's
+        ``max_model_len - seq_len`` default, reintroducing the original
+        IndexError.
+        """
+        glm_request.max_tokens = None
+        glm_request.model_fields_set = {"max_tokens"}
+
+        result = glm_serving_chat._apply_request_overrides(default_comprehension_params, glm_request)
+        # t2i 1024x1024 = 1281; must override the null the user sent
+        assert result.max_tokens == 1281
diff --git a/tests/inputs/test_preprocess.py b/tests/inputs/test_preprocess.py
new file mode 100644
index 00000000000..db8817cf2d1
--- /dev/null
+++ b/tests/inputs/test_preprocess.py
@@ -0,0 +1,43 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Unit tests for OmniInputPreprocessor._process_text routing."""
+
+import pytest
+from pytest_mock import MockerFixture
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+class TestProcessTextMmProcessorKwargsRouting:
+    """Routing on ``mm_processor_kwargs`` being set, not on its truthiness.
+
+    An explicit empty ``mm_processor_kwargs`` dict on the prompt still routes
+    through ``_process_multimodal``. Required for AR-based image generation
+    (e.g. GLM-Image t2i) where the HF processor supplies the image-generation
+    scaffold from its own defaults and no extra kwargs are needed from the
+    caller. A missing key or an explicit ``None`` keeps the plain
+    ``_tokenize_prompt`` path.
+    """
+
+    @pytest.fixture
+    def preprocessor(self, mocker: MockerFixture):
+        from vllm_omni.inputs.preprocess import OmniInputPreprocessor
+
+        instance = object.__new__(OmniInputPreprocessor)
+        instance._process_multimodal = mocker.MagicMock(return_value={})
+        instance._tokenize_prompt = mocker.MagicMock(return_value=[1, 2, 3])
+        return instance
+
+    def test_empty_mm_processor_kwargs_routes_to_multimodal(self, preprocessor):
+        preprocessor._process_text({"prompt": "hello", "mm_processor_kwargs": {}})
+        assert preprocessor._process_multimodal.called
+        assert not preprocessor._tokenize_prompt.called
+
+    def test_missing_mm_processor_kwargs_routes_to_tokenize(self, preprocessor):
+        preprocessor._process_text({"prompt": "hello"})
+        assert preprocessor._tokenize_prompt.called
+        assert not preprocessor._process_multimodal.called
+
+    def test_none_mm_processor_kwargs_routes_to_tokenize(self, preprocessor):
+        preprocessor._process_text({"prompt": "hello", "mm_processor_kwargs": None})
+        assert preprocessor._tokenize_prompt.called
+        assert not preprocessor._process_multimodal.called
diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py
index fb96b397eb7..4ca3775be1d 100644
--- a/vllm_omni/entrypoints/openai/serving_chat.py
+++ b/vllm_omni/entrypoints/openai/serving_chat.py
@@ -351,14 +351,18 @@ async def create_chat_completion(
                 tprompt["modalities"] = ["image"]
                 if negative_prompt is not None:
                     tprompt["negative_prompt"] = negative_prompt
-                # GLM-Image's _call_hf_processor expects target_h/target_w in mm_processor_kwargs
+                # Always attach mm_processor_kwargs (possibly empty) so
+                # OmniInputPreprocessor._process_text routes through the
+                # multimodal processor path. Without it, the preprocessor
+                # falls back to plain _tokenize_prompt and AR-based image-gen
+                # models like GLM-Image never see their image-generation
+                # scaffold.
                 mm_processor_kwargs: dict[str, Any] = {}
                 if height is not None:
                     mm_processor_kwargs["target_h"] = height
                 if width is not None:
                     mm_processor_kwargs["target_w"] = width
-                if mm_processor_kwargs:
-                    tprompt["mm_processor_kwargs"] = mm_processor_kwargs
+                tprompt["mm_processor_kwargs"] = mm_processor_kwargs
                 if engine_prompt_image is not None:
                     tprompt["multi_modal_data"] = engine_prompt_image
                     # Provide multi_modal_uuids so that newer vLLM versions
@@ -736,6 +740,22 @@ def _apply_request_overrides(
         extra_body = getattr(request, "extra_body", {}) or {}
         height, width = self._resolve_height_width_from_extra_body(extra_body)
 
+        # Fall back to the diffusion stage's default h/w when the user didn't
+        # specify them, so the compute works for the bare-curl request shape
+        # (no extra_body). Implicit gate: only fires when a stage in the
+        # pipeline declares height/width in its sampling params (e.g. GLM-Image
+        # stage-1 yaml); LLM-only / audio pipelines have neither and are skipped.
+        if height is None or width is None:
+            for dp in self.engine_client.default_sampling_params_list or []:
+                stage_h = getattr(dp, "height", None)
+                stage_w = getattr(dp, "width", None)
+                if stage_h is not None and stage_w is not None:
+                    if height is None:
+                        height = stage_h
+                    if width is None:
+                        width = stage_w
+                    break
+
         # Best-effort mode detection from user messages.
         # i2i requests include at least one reference image in message content.
         _, reference_images = self._extract_diffusion_prompt_and_images_from_messages(request.messages)
@@ -757,15 +777,7 @@ def _apply_request_overrides(
                 extra_args["target_w"] = int(width)
                 params.extra_args = extra_args
             except (ImportError, ValueError, TypeError) as e:
-                logger.warning(f"Failed to compute max_tokens: {e}, using default if available")
-        else:
-            logger.info(
-                "[SamplingParams] Skip dynamic max_tokens (height=%s, width=%s, mode=%s, ref_images=%s)",
-                height,
-                width,
-                "i2i" if is_img2img else "t2i",
-                ref_image_count,
-            )
+                logger.warning("Failed to compute max_tokens: %s", e)
 
         return params
 
diff --git a/vllm_omni/inputs/preprocess.py b/vllm_omni/inputs/preprocess.py
index cca6ce56870..7282d7a520d 100644
--- a/vllm_omni/inputs/preprocess.py
+++ b/vllm_omni/inputs/preprocess.py
@@ -60,7 +60,11 @@ def _process_text(
             additional_information = parsed_content.get("additional_information")
             if additional_information is not None:
                 inputs["additional_information"] = additional_information
-        elif mm_processor_kwargs:
+        elif parsed_content.get("mm_processor_kwargs") is not None:
+            # Set-ness, not truthiness: an explicitly-set empty dict still
+            # signals "route through the multimodal processor" (needed for
+            # AR-based image-gen where the HF processor supplies its own
+            # defaults and scaffold); a missing key or explicit None does not.
             inputs = self._process_multimodal(
                 prompt_text,
                 {},