diff --git a/docs/configuration/stage_configs.md b/docs/configuration/stage_configs.md index 55b4053cc71..546e6408a68 100644 --- a/docs/configuration/stage_configs.md +++ b/docs/configuration/stage_configs.md @@ -279,6 +279,20 @@ A custom Python function hook for the LLM stage (Stage 0) that expands a single A custom Python function hook for downstream diffusion stages (Stage 1+) to collect, map, and process the KV caches transferred from the companion requests fired by `prompt_expand_func`. It aggregates the hidden condition states cleanly (e.g., binding them as `cfg_text_past_key_values` and `cfg_text_kv_metadata`), allowing the diffusion runtime to perform CFG smoothly without redundantly evaluating text paths on the DiT workers. +### `prompt_preprocess_func` (Optional) + +A custom Python function hook that transforms the raw user prompt before Stage 0 tokenization. This enables server-side chat template application, system prompt prepending, or other prompt formatting without requiring client-side changes. + +The function receives the prompt (either a dict with a "prompt" key from the API endpoint, or a raw string) and returns the transformed prompt in the same format. + +Example use case: a multi-stage image generation pipeline where Stage 0 is a prompt rewriting LLM (e.g., GLM-Image) that needs chat-formatted input, but the client sends raw text prompts via `/v1/images/generations`. + +The function is specified as a dotted Python path in the YAML config: + + prompt_preprocess_func: vllm_omni.model_executor.stage_input_processors.glm_image.preprocess_prompt_for_glm + +The preprocessor runs in `_build_add_request_message()` BEFORE `InputProcessor.process_inputs()` tokenizes the prompt. Only one `prompt_preprocess_func` is active at a time (from the last stage that defines it), matching the behavior of `prompt_expand_func`. + ### `runtime` Configuration for disaggregated execution of the stage, controlling how the stage is deployed and executed. diff --git a/docs/contributing/model/adding_omni_model.md b/docs/contributing/model/adding_omni_model.md index 1eaff10596c..ba1d5d95cb2 100644 --- a/docs/contributing/model/adding_omni_model.md +++ b/docs/contributing/model/adding_omni_model.md @@ -315,6 +315,12 @@ The registry uses lazy loading, so the model class is imported only when needed. Create a YAML configuration file in `vllm_omni/deploy/`. For a complete example, see the [Qwen3-Omni configuration file](gh-file:vllm_omni/deploy/qwen3_omni_moe.yaml). +#### Optional: Server-Side Prompt Preprocessing + +If your model's AR stage requires a specific prompt format (e.g., chat template with system prompt), you can use `prompt_preprocess_func` to apply the formatting server-side. This keeps the client API clean - clients send raw prompts without needing to know the model's internal formatting requirements. + +See `docs/configuration/stage_configs.md` for the `prompt_preprocess_func` reference. + ### Key Configuration Fields - **`model_stage`**: Which stage to run ("thinker", "talker", "code2wav", etc.) diff --git a/tests/engine/test_prompt_preprocess_func.py b/tests/engine/test_prompt_preprocess_func.py new file mode 100644 index 00000000000..9a8a5c75044 --- /dev/null +++ b/tests/engine/test_prompt_preprocess_func.py @@ -0,0 +1,164 @@ +"""Tests for prompt_preprocess_func stage config field.""" + +import types + +import pytest + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +# --------------------------------------------------------------------------- +# A trivial preprocessor used by tests (avoids depending on any model code). +# --------------------------------------------------------------------------- + + +def _identity_preprocess(prompt): + """No-op preprocessor for testing.""" + return prompt + + +# --------------------------------------------------------------------------- +# extract_stage_metadata loading tests +# --------------------------------------------------------------------------- + + +def _make_llm_stage_config(prompt_preprocess_func=None): + """Build a minimal LLM stage config namespace for extract_stage_metadata.""" + engine_args = { + "model_stage": "ar", + "engine_output_type": "token_ids", + } + cfg = types.SimpleNamespace( + stage_id=0, + stage_type="llm", + engine_args=engine_args, + runtime={}, + engine_input_source=[], + final_output=False, + final_output_type=None, + default_sampling_params={}, + is_comprehension=True, + ) + if prompt_preprocess_func is not None: + cfg.prompt_preprocess_func = prompt_preprocess_func + # Ensure other optional func fields exist so getattr doesn't fall through + cfg.prompt_expand_func = None + cfg.cfg_kv_collect_func = None + cfg.custom_process_input_func = None + return cfg + + +def test_prompt_preprocess_func_loaded_from_config(): + """Verify prompt_preprocess_func is resolved from a dotted path.""" + from vllm_omni.engine.stage_init_utils import extract_stage_metadata + + # Point at a known built-in function to verify importlib resolution. + stage_config = _make_llm_stage_config( + prompt_preprocess_func="copy.copy", + ) + metadata = extract_stage_metadata(stage_config) + assert metadata.prompt_preprocess_func is not None + assert callable(metadata.prompt_preprocess_func) + + +def test_prompt_preprocess_func_none_when_not_configured(): + """Backward compat: missing field results in None.""" + from vllm_omni.engine.stage_init_utils import extract_stage_metadata + + stage_config = _make_llm_stage_config(prompt_preprocess_func=None) + metadata = extract_stage_metadata(stage_config) + assert metadata.prompt_preprocess_func is None + + +def test_prompt_preprocess_func_none_when_attr_missing(): + """Backward compat: attribute not present at all results in None.""" + from vllm_omni.engine.stage_init_utils import extract_stage_metadata + + stage_config = _make_llm_stage_config() + # Remove the attribute entirely + if hasattr(stage_config, "prompt_preprocess_func"): + delattr(stage_config, "prompt_preprocess_func") + metadata = extract_stage_metadata(stage_config) + assert metadata.prompt_preprocess_func is None + + +# --------------------------------------------------------------------------- +# _initialize_stages collects prompt_preprocess_func +# --------------------------------------------------------------------------- + + +def test_initialize_stages_collects_prompt_preprocess_func(monkeypatch): + """Verify _initialize_stages stores prompt_preprocess_func on self.""" + import vllm_omni.engine.async_omni_engine as engine_mod + from vllm_omni.engine.async_omni_engine import AsyncOmniEngine + from vllm_omni.platforms import current_omni_platform + + engine = object.__new__(AsyncOmniEngine) + engine.model = "dummy-model" + engine.config_path = "dummy-config" + engine.num_stages = 1 + engine.async_chunk = False + engine.diffusion_batch_size = 1 + engine.single_stage_mode = False + engine._single_stage_id_filter = None + engine._omni_master_server = None + engine.stage_configs = [ + types.SimpleNamespace(stage_id=0, stage_type="diffusion"), + ] + + env_var = current_omni_platform.device_control_env_var + old_env = __import__("os").environ.get(env_var) + __import__("os").environ[env_var] = "0" + + _sentinel = lambda p: p # noqa: E731 + + metadata = types.SimpleNamespace( + stage_id=0, + stage_type="diffusion", + runtime_cfg={"devices": "0"}, + prompt_expand_func=None, + prompt_preprocess_func=_sentinel, + ) + + monkeypatch.setattr(engine_mod, "prepare_engine_environment", lambda: None) + monkeypatch.setattr( + engine_mod, + "load_omni_transfer_config_for_model", + lambda *_: None, + ) + monkeypatch.setattr( + engine_mod, + "extract_stage_metadata", + lambda _cfg: metadata, + ) + monkeypatch.setattr(engine_mod, "get_stage_connector_spec", lambda **_: {}) + monkeypatch.setattr( + engine_mod, + "resolve_omni_kv_config_for_stage", + lambda *_: (None, None, None), + ) + monkeypatch.setattr(engine_mod, "setup_stage_devices", lambda *_: None) + monkeypatch.setattr(engine_mod, "inject_kv_stage_info", lambda *_: None) + monkeypatch.setattr( + engine_mod, + "initialize_diffusion_stage", + lambda *_, **__: types.SimpleNamespace(is_comprehension=False), + ) + monkeypatch.setattr( + engine_mod, + "finalize_initialized_stages", + lambda stage_clients, _ip: ( + stage_clients, + [types.SimpleNamespace()], + [{"final_output_type": "image", "stage_type": "diffusion"}], + ), + ) + + try: + engine._initialize_stages(stage_init_timeout=1) + assert engine.prompt_preprocess_func is _sentinel + finally: + if old_env is None: + __import__("os").environ.pop(env_var, None) + else: + __import__("os").environ[env_var] = old_env diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 61da4388be0..1bb2955428e 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -692,6 +692,8 @@ def _initialize_stages(self, stage_init_timeout: int) -> None: async_chunk = self.async_chunk prompt_expand_func = None + prompt_preprocess_func = None + llm_stage_count = sum( 1 for stage_cfg in self.stage_configs if getattr(stage_cfg, "stage_type", "llm") != "diffusion" ) @@ -742,6 +744,9 @@ def _initialize_stages(self, stage_init_timeout: int) -> None: if metadata.prompt_expand_func is not None: prompt_expand_func = metadata.prompt_expand_func + if metadata.prompt_preprocess_func is not None: + prompt_preprocess_func = metadata.prompt_preprocess_func + if self.single_stage_mode: metadata.runtime_cfg = None @@ -888,6 +893,8 @@ def _initialize_stages(self, stage_init_timeout: int) -> None: self.stage_vllm_configs = stage_vllm_configs self.input_processor = input_processor self.prompt_expand_func = prompt_expand_func + self.prompt_preprocess_func = prompt_preprocess_func + # TODO(Peiqi): Hack here supported_tasks: set[str] = set() if any(getattr(stage_client, "is_comprehension", False) for stage_client in initialized_stage_clients): @@ -1030,6 +1037,10 @@ def _build_add_request_message( stage_type = self.stage_metadata[0].get("stage_type") if stage_type != "diffusion" and not isinstance(prompt, EngineCoreRequest): + # Apply server-side prompt preprocessing before tokenization. + if self.prompt_preprocess_func is not None: + prompt = self.prompt_preprocess_func(prompt) + # Inject global_request_id into the raw prompt. if isinstance(prompt, dict): _inject_global_id(prompt, request_id) diff --git a/vllm_omni/engine/stage_init_utils.py b/vllm_omni/engine/stage_init_utils.py index cc7676ba5d4..1b977bad0ac 100644 --- a/vllm_omni/engine/stage_init_utils.py +++ b/vllm_omni/engine/stage_init_utils.py @@ -259,6 +259,7 @@ class StageMetadata: runtime_cfg: Any prompt_expand_func: Callable | None = None cfg_kv_collect_func: Callable | None = None + prompt_preprocess_func: Callable | None = None @dataclass @@ -321,6 +322,12 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata: _mod, _fn = _ckf_path.rsplit(".", 1) cfg_kv_collect_func = getattr(importlib.import_module(_mod), _fn) + prompt_preprocess_func: Callable | None = None + _ppf_path = getattr(stage_config, "prompt_preprocess_func", None) + if _ppf_path: + _mod, _fn = _ppf_path.rsplit(".", 1) + prompt_preprocess_func = getattr(importlib.import_module(_mod), _fn) + if stage_type == "diffusion": return StageMetadata( stage_id=stage_id, @@ -357,6 +364,7 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata: model_stage=model_stage, runtime_cfg=runtime_cfg, prompt_expand_func=prompt_expand_func, + prompt_preprocess_func=prompt_preprocess_func, )