Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docs/configuration/stage_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,20 @@ A custom Python function hook for the LLM stage (Stage 0) that expands a single

A custom Python function hook for downstream diffusion stages (Stage 1+) to collect, map, and process the KV caches transferred from the companion requests fired by `prompt_expand_func`. It aggregates the hidden condition states cleanly (e.g., binding them as `cfg_text_past_key_values` and `cfg_text_kv_metadata`), allowing the diffusion runtime to perform CFG smoothly without redundantly evaluating text paths on the DiT workers.

### `prompt_preprocess_func` (Optional)

A custom Python function hook that transforms the raw user prompt before Stage 0 tokenization. This enables server-side chat template application, system prompt prepending, or other prompt formatting without requiring client-side changes.

The function receives the prompt (either a dict with a "prompt" key from the API endpoint, or a raw string) and returns the transformed prompt in the same format.

Example use case: a multi-stage image generation pipeline where Stage 0 is a prompt rewriting LLM (e.g., GLM-Image) that needs chat-formatted input, but the client sends raw text prompts via `/v1/images/generations`.

The function is specified as a dotted Python path in the YAML config:

```yaml
prompt_preprocess_func: vllm_omni.model_executor.stage_input_processors.glm_image.preprocess_prompt_for_glm
```

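The preprocessor runs in `_build_add_request_message()` before `InputProcessor.process_inputs()` tokenizes the prompt. It is skipped when Stage 0 is a diffusion stage or when the prompt is already an `EngineCoreRequest`. Only one `prompt_preprocess_func` is active at a time: if several stages define one, the function from the last such stage is used, matching the behavior of `prompt_expand_func`.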

### `runtime`

Configuration for disaggregated execution of the stage, controlling how the stage is deployed and executed.
Expand Down
6 changes: 6 additions & 0 deletions docs/contributing/model/adding_omni_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,12 @@ The registry uses lazy loading, so the model class is imported only when needed.

Create a YAML configuration file in `vllm_omni/deploy/`. For a complete example, see the [Qwen3-Omni configuration file](gh-file:vllm_omni/deploy/qwen3_omni_moe.yaml).

#### Optional: Server-Side Prompt Preprocessing

If your model's AR stage requires a specific prompt format (e.g., a chat template with a system prompt), you can use `prompt_preprocess_func` to apply the formatting server-side. This keeps the client API clean: clients send raw prompts without needing to know the model's internal formatting requirements.

See `docs/configuration/stage_configs.md` for the `prompt_preprocess_func` reference.

### Key Configuration Fields

- **`model_stage`**: Which stage to run ("thinker", "talker", "code2wav", etc.)
Expand Down
164 changes: 164 additions & 0 deletions tests/engine/test_prompt_preprocess_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""Tests for prompt_preprocess_func stage config field."""

import types

import pytest

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]


# ---------------------------------------------------------------------------
# A trivial preprocessor used by tests (avoids depending on any model code).
# ---------------------------------------------------------------------------


def _identity_preprocess(prompt):
"""No-op preprocessor for testing."""
return prompt


# ---------------------------------------------------------------------------
# extract_stage_metadata loading tests
# ---------------------------------------------------------------------------


def _make_llm_stage_config(prompt_preprocess_func=None):
"""Build a minimal LLM stage config namespace for extract_stage_metadata."""
engine_args = {
"model_stage": "ar",
"engine_output_type": "token_ids",
}
cfg = types.SimpleNamespace(
stage_id=0,
stage_type="llm",
engine_args=engine_args,
runtime={},
engine_input_source=[],
final_output=False,
final_output_type=None,
default_sampling_params={},
is_comprehension=True,
)
if prompt_preprocess_func is not None:
cfg.prompt_preprocess_func = prompt_preprocess_func
# Ensure other optional func fields exist so getattr doesn't fall through
cfg.prompt_expand_func = None
cfg.cfg_kv_collect_func = None
cfg.custom_process_input_func = None
return cfg


def test_prompt_preprocess_func_loaded_from_config():
    """Verify prompt_preprocess_func is resolved from a dotted path."""
    import copy

    from vllm_omni.engine.stage_init_utils import extract_stage_metadata

    # Point at a standard-library function so the importlib-based resolution
    # can be checked without depending on any model code.
    stage_config = _make_llm_stage_config(
        prompt_preprocess_func="copy.copy",
    )
    metadata = extract_stage_metadata(stage_config)

    # Identity check rather than a bare callable() check: the dotted path must
    # resolve to exactly the object it names, not merely to some callable.
    assert metadata.prompt_preprocess_func is copy.copy


def test_prompt_preprocess_func_none_when_not_configured():
    """Backward compat: a config built without a preprocessor yields None."""
    from vllm_omni.engine.stage_init_utils import extract_stage_metadata

    metadata = extract_stage_metadata(
        _make_llm_stage_config(prompt_preprocess_func=None),
    )
    assert metadata.prompt_preprocess_func is None


def test_prompt_preprocess_func_none_when_attr_missing():
    """Backward compat: a config lacking the attribute altogether yields None."""
    from vllm_omni.engine.stage_init_utils import extract_stage_metadata

    stage_config = _make_llm_stage_config()
    # Drop the attribute if the helper set it; a no-op when it is already absent.
    vars(stage_config).pop("prompt_preprocess_func", None)

    metadata = extract_stage_metadata(stage_config)
    assert metadata.prompt_preprocess_func is None


# ---------------------------------------------------------------------------
# _initialize_stages collects prompt_preprocess_func
# ---------------------------------------------------------------------------


def test_initialize_stages_collects_prompt_preprocess_func(monkeypatch):
    """Verify _initialize_stages stores prompt_preprocess_func on self.

    Every collaborator that ``_initialize_stages`` looks up on the engine
    module is replaced with a stub, and ``extract_stage_metadata`` is stubbed
    to return metadata carrying a sentinel preprocessor. The test then checks
    that the sentinel ends up on ``engine.prompt_preprocess_func``.
    """
    import vllm_omni.engine.async_omni_engine as engine_mod
    from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
    from vllm_omni.platforms import current_omni_platform

    # Bypass __init__ and set only the attributes this test assigns explicitly.
    engine = object.__new__(AsyncOmniEngine)
    engine.model = "dummy-model"
    engine.config_path = "dummy-config"
    engine.num_stages = 1
    engine.async_chunk = False
    engine.diffusion_batch_size = 1
    engine.single_stage_mode = False
    engine._single_stage_id_filter = None
    engine._omni_master_server = None
    engine.stage_configs = [
        types.SimpleNamespace(stage_id=0, stage_type="diffusion"),
    ]

    # monkeypatch restores the variable at teardown even if a later setup
    # step raises, so the environment cannot leak into other tests.
    monkeypatch.setenv(current_omni_platform.device_control_env_var, "0")

    def _sentinel(p):
        return p

    metadata = types.SimpleNamespace(
        stage_id=0,
        stage_type="diffusion",
        runtime_cfg={"devices": "0"},
        prompt_expand_func=None,
        prompt_preprocess_func=_sentinel,
    )

    monkeypatch.setattr(engine_mod, "prepare_engine_environment", lambda: None)
    monkeypatch.setattr(
        engine_mod,
        "load_omni_transfer_config_for_model",
        lambda *_: None,
    )
    monkeypatch.setattr(
        engine_mod,
        "extract_stage_metadata",
        lambda _cfg: metadata,
    )
    monkeypatch.setattr(engine_mod, "get_stage_connector_spec", lambda **_: {})
    monkeypatch.setattr(
        engine_mod,
        "resolve_omni_kv_config_for_stage",
        lambda *_: (None, None, None),
    )
    monkeypatch.setattr(engine_mod, "setup_stage_devices", lambda *_: None)
    monkeypatch.setattr(engine_mod, "inject_kv_stage_info", lambda *_: None)
    monkeypatch.setattr(
        engine_mod,
        "initialize_diffusion_stage",
        lambda *_, **__: types.SimpleNamespace(is_comprehension=False),
    )
    monkeypatch.setattr(
        engine_mod,
        "finalize_initialized_stages",
        lambda stage_clients, _ip: (
            stage_clients,
            [types.SimpleNamespace()],
            [{"final_output_type": "image", "stage_type": "diffusion"}],
        ),
    )

    engine._initialize_stages(stage_init_timeout=1)
    assert engine.prompt_preprocess_func is _sentinel
11 changes: 11 additions & 0 deletions vllm_omni/engine/async_omni_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,8 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:

async_chunk = self.async_chunk
prompt_expand_func = None
prompt_preprocess_func = None

llm_stage_count = sum(
1 for stage_cfg in self.stage_configs if getattr(stage_cfg, "stage_type", "llm") != "diffusion"
)
Expand Down Expand Up @@ -742,6 +744,9 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
if metadata.prompt_expand_func is not None:
prompt_expand_func = metadata.prompt_expand_func

if metadata.prompt_preprocess_func is not None:
prompt_preprocess_func = metadata.prompt_preprocess_func

if self.single_stage_mode:
metadata.runtime_cfg = None

Expand Down Expand Up @@ -888,6 +893,8 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
self.stage_vllm_configs = stage_vllm_configs
self.input_processor = input_processor
self.prompt_expand_func = prompt_expand_func
self.prompt_preprocess_func = prompt_preprocess_func

# TODO(Peiqi): Hack here
supported_tasks: set[str] = set()
if any(getattr(stage_client, "is_comprehension", False) for stage_client in initialized_stage_clients):
Expand Down Expand Up @@ -1030,6 +1037,10 @@ def _build_add_request_message(

stage_type = self.stage_metadata[0].get("stage_type")
if stage_type != "diffusion" and not isinstance(prompt, EngineCoreRequest):
# Apply server-side prompt preprocessing before tokenization.
if self.prompt_preprocess_func is not None:
prompt = self.prompt_preprocess_func(prompt)

# Inject global_request_id into the raw prompt.
if isinstance(prompt, dict):
_inject_global_id(prompt, request_id)
Expand Down
8 changes: 8 additions & 0 deletions vllm_omni/engine/stage_init_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ class StageMetadata:
runtime_cfg: Any
prompt_expand_func: Callable | None = None
cfg_kv_collect_func: Callable | None = None
prompt_preprocess_func: Callable | None = None


@dataclass
Expand Down Expand Up @@ -321,6 +322,12 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata:
_mod, _fn = _ckf_path.rsplit(".", 1)
cfg_kv_collect_func = getattr(importlib.import_module(_mod), _fn)

prompt_preprocess_func: Callable | None = None
_ppf_path = getattr(stage_config, "prompt_preprocess_func", None)
if _ppf_path:
_mod, _fn = _ppf_path.rsplit(".", 1)
prompt_preprocess_func = getattr(importlib.import_module(_mod), _fn)

if stage_type == "diffusion":
return StageMetadata(
stage_id=stage_id,
Expand Down Expand Up @@ -357,6 +364,7 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata:
model_stage=model_stage,
runtime_cfg=runtime_cfg,
prompt_expand_func=prompt_expand_func,
prompt_preprocess_func=prompt_preprocess_func,
)


Expand Down
Loading