From a5a4998ece6374a76469ec3ac22e7b55c95e4e84 Mon Sep 17 00:00:00 2001 From: "Bingyu (Spencer) Liu" Date: Fri, 17 Apr 2026 18:43:17 +0800 Subject: [PATCH 01/38] [Feature] Support Prefill-Decode disaggregation via vLLM KV transfer (#2220) Signed-off-by: LiuBingyu --- tests/e2e/online_serving/test_qwen3_omni.py | 36 +- tests/entrypoints/test_pd_disaggregation.py | 1222 +++++++++++++++++ vllm_omni/engine/async_omni_engine.py | 46 + vllm_omni/engine/orchestrator.py | 132 +- vllm_omni/engine/output_processor.py | 5 +- vllm_omni/entrypoints/async_omni.py | 19 +- vllm_omni/entrypoints/omni.py | 16 +- vllm_omni/entrypoints/omni_base.py | 12 +- vllm_omni/entrypoints/pd_utils.py | 57 +- .../models/qwen3_omni/qwen3_omni.py | 32 +- .../stage_input_processors/qwen3_omni.py | 143 +- 11 files changed, 1672 insertions(+), 48 deletions(-) create mode 100644 tests/entrypoints/test_pd_disaggregation.py diff --git a/tests/e2e/online_serving/test_qwen3_omni.py b/tests/e2e/online_serving/test_qwen3_omni.py index 13af2ad1109..9737fa42bdb 100644 --- a/tests/e2e/online_serving/test_qwen3_omni.py +++ b/tests/e2e/online_serving/test_qwen3_omni.py @@ -26,9 +26,15 @@ QWEN3_OMNI_CONFIG_PATH = str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml") QWEN3_OMNI_XPU_CONFIG_PATH = str(Path(__file__).parent.parent / "stage_configs" / "xpu" / "qwen3_omni_ci.yaml") +_STAGE_CONFIGS_DIR = Path(__file__).parent.parent / "stage_configs" +_PD_SEP_CONFIG = str(_STAGE_CONFIGS_DIR / "qwen3_omni_moe_pd_ci.yaml") -def get_chunk_config(config_path: str): - path = modify_stage_config( + +def get_chunk_config(config_path: str | None = None): + """Load qwen3_omni_ci.yaml with async_chunk modifications for streaming mode.""" + if config_path is None: + config_path = str(_STAGE_CONFIGS_DIR / "qwen3_omni_ci.yaml") + return modify_stage_config( config_path, updates={ "async_chunk": True, @@ -43,7 +49,6 @@ def get_chunk_config(config_path: str): }, deletes={"stage_args": {2: ["custom_process_input_func"]}}, ) - return path def get_prefix_caching_config(config_path: str): @@ -59,11 +64,20 @@ def get_prefix_caching_config(config_path: str): return path -if current_omni_platform.is_xpu(): - stage_configs = [QWEN3_OMNI_XPU_CONFIG_PATH] - prefix_caching_stage_configs = [get_prefix_caching_config(QWEN3_OMNI_XPU_CONFIG_PATH)] -else: # MI325 GPU should share the same config as H100 - stage_configs = [get_chunk_config(QWEN3_OMNI_CONFIG_PATH)] +# Set VLLM_TEST_PD_MODE=1 to test PD disaggregation, default tests async_chunk mode. +_USE_PD = os.environ.get("VLLM_TEST_PD_MODE", "0") == "1" + +# Stage configs for H100/CUDA, ROCm MI325, and XPU platforms +if current_omni_platform.is_rocm(): + rocm_config = str(_STAGE_CONFIGS_DIR / "rocm" / "qwen3_omni_ci.yaml") + stage_configs = [rocm_config] + prefix_caching_stage_configs = [get_prefix_caching_config(rocm_config)] +elif current_omni_platform.is_xpu(): + xpu_config = str(_STAGE_CONFIGS_DIR / "xpu" / "qwen3_omni_ci.yaml") + stage_configs = [xpu_config] + prefix_caching_stage_configs = [get_prefix_caching_config(xpu_config)] +else: + stage_configs = [_PD_SEP_CONFIG if _USE_PD else get_chunk_config(QWEN3_OMNI_CONFIG_PATH)] prefix_caching_stage_configs = [get_prefix_caching_config(QWEN3_OMNI_CONFIG_PATH)] # Create parameter combinations for model and stage config @@ -116,7 +130,8 @@ def get_max_batch_size(size_type="few"): @pytest.mark.advanced_model @pytest.mark.core_model @pytest.mark.omni -@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2) +@pytest.mark.skipif(_USE_PD, reason="Temporarily skip PD mode in this test module.") +@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=3 if _USE_PD else 2) @pytest.mark.parametrize("omni_server", test_params, indirect=True) def test_mix_to_text_audio_001(omni_server, openai_client) -> None: """ @@ -155,7 +170,8 @@ def test_mix_to_text_audio_001(omni_server, openai_client) -> None: @pytest.mark.advanced_model @pytest.mark.core_model @pytest.mark.omni -@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2) +@pytest.mark.skipif(_USE_PD, reason="Temporarily skip PD mode in this test module.") +@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=3 if _USE_PD else 2) @pytest.mark.parametrize("omni_server", test_params, indirect=True) def test_text_to_text_001(omni_server, openai_client) -> None: """ diff --git a/tests/entrypoints/test_pd_disaggregation.py b/tests/entrypoints/test_pd_disaggregation.py new file mode 100644 index 00000000000..5ffabfbf2af --- /dev/null +++ b/tests/entrypoints/test_pd_disaggregation.py @@ -0,0 +1,1222 @@ +"""Unit tests for PD (Prefill-Decode) disaggregation in the Omni orchestrator. + +Tests the PD detection, validation, config parsing, sampling param +preparation, and routing logic added by the PD disaggregation feature +(issue #1188). All tests run without GPU. + +NOTE (v1908 adaptation): Tests that relied on the old OmniStage / stage_list +architecture (removed in PR #1908) are marked xfail with +``reason="Requires migration to v1908 Orchestrator architecture"``. +The remaining tests exercise PDDisaggregationMixin directly and work +without spinning up a real engine. +""" + +import uuid +import warnings +from queue import Empty, Queue +from types import SimpleNamespace +from typing import Any + +import pytest +from vllm import SamplingParams + +from vllm_omni.entrypoints.pd_utils import PDDisaggregationMixin + +pytestmark = pytest.mark.skip(reason="Temporarily skip PD entrypoint tests while PD config is being removed.") + +# Suppress noisy DeprecationWarnings from optional Swig bindings imported by vLLM dependencies. +warnings.filterwarnings( + "ignore", + message=r"builtin type SwigPy.*has no __module__ attribute", + category=DeprecationWarning, +) + + +def _ns(**kwargs): + """Create a lightweight attribute object for tests.""" + return SimpleNamespace(**kwargs) + + +# --------------------------------------------------------------------------- +# Fake helpers (same pattern as test_omni_llm.py) +# --------------------------------------------------------------------------- + + +class _FakeEngineArgs(dict): + """Fake engine args that supports both attribute and dict access.""" + + def __init__(self, args_dict: dict[str, Any]): + super().__init__(args_dict) + if "model_stage" not in self: + self["model_stage"] = None + if "engine_output_type" not in self: + self["engine_output_type"] = None + for key, value in self.items(): + setattr(self, key, value) + + +class _FakeStageConfig: + def __init__(self, config_dict: dict[str, Any]): + engine_args_dict = config_dict.get("engine_args", {}) + self.engine_args = _FakeEngineArgs(engine_args_dict) + self.final_output = config_dict.get("final_output", False) + self.final_output_type = config_dict.get("final_output_type", None) + self.stage_id = config_dict.get("stage_id", 0) + self.is_prefill_only = config_dict.get("is_prefill_only", False) + self.is_decode_only = config_dict.get("is_decode_only", False) + self.engine_input_source = config_dict.get("engine_input_source", []) + self.is_comprehension = config_dict.get("is_comprehension", False) + self._config_dict = config_dict + + +class _FakeQueue: + def __init__(self, maxsize=0): + self._queue = Queue(maxsize=maxsize) + + def put(self, item): + self._queue.put(item) + + def put_nowait(self, item): + self._queue.put_nowait(item) + + def get(self): + return self._queue.get() + + def get_nowait(self): + return self._queue.get_nowait() + + def empty(self): + return self._queue.empty() + + +class _FakeStage: + """Lightweight stage stub with PD disaggregation flag support.""" + + def __init__(self, config, stage_init_timeout: int = 300): + if isinstance(config, dict): + config = _FakeStageConfig(config) + self.config = config + self.stage_config = config + self.engine = None + self.engine_outputs = None + self.stage_id = getattr(config, "stage_id", 0) + self.engine_args = config.engine_args + self.model_stage = getattr(config.engine_args, "model_stage", None) + self.stage_type = "llm" + self.default_sampling_params = SamplingParams(temperature=1.0) + self.final_output = config.final_output if hasattr(config, "final_output") else False + self.final_output_type = getattr(config, "final_output_type", None) + self.is_prefill_only = getattr(config, "is_prefill_only", False) + self.is_decode_only = getattr(config, "is_decode_only", False) + self.engine_input_source = getattr(config, "engine_input_source", []) + self.is_comprehension = getattr(config, "is_comprehension", False) + processed_input = getattr(config, "_config_dict", {}).get("processed_input", ["processed"]) + self._processed_input = processed_input + self._in_q = None + self._out_q = None + self._proc = None + self._stage_init_timeout = max(0, int(stage_init_timeout)) + + def attach_queues(self, in_q, out_q): + self._in_q = in_q + self._out_q = out_q + + def init_stage_worker( + self, model: str, *, is_async=False, shm_threshold_bytes=65536, ctx=None, batch_timeout=10, **kwargs + ): + self._proc = _ns( + start=lambda: None, + join=lambda timeout=None: None, + is_alive=lambda: False, + terminate=lambda: None, + ) + if self._out_q is not None: + try: + self._out_q.put_nowait({"type": "stage_ready", "stage_id": self.stage_id}) + except Exception: + pass + + def stop_stage_worker(self): + if self._in_q is not None: + try: + self._in_q.put_nowait({"type": "shutdown"}) + except Exception: + pass + + def submit(self, payload: dict[str, Any]): + if self._in_q is not None: + self._in_q.put(payload) + + def try_collect(self) -> Any: + if self._out_q is None: + return None + try: + return self._out_q.get_nowait() + except Empty: + return None + + def set_engine_outputs(self, outputs): + self.engine_outputs = outputs + + def process_engine_inputs(self, stage_list, prompts): + return self._processed_input + + +# --------------------------------------------------------------------------- +# Shared mock setup helpers +# --------------------------------------------------------------------------- + + +def _setup_engine_mocks(monkeypatch): + fake_engine = _ns() + fake_engine.tokenizer = _ns() + fake_engine.log_stats = False + fake_engine.vllm_config = _ns() + fake_engine.vllm_config.model_config = _ns() + fake_engine.vllm_config.model_config.io_processor_plugin = None + fake_engine.get_supported_tasks = lambda: [] + fake_engine.model_config = _ns() + fake_engine.model_config.io_processor_plugin = None + fake_registry = _ns() + fake_registry.resolve_model_cls = lambda *args, **kwargs: (_ns(), "test_arch") + fake_engine.model_config.registry = fake_registry + fake_engine.vllm_config.model_config.registry = fake_registry + + monkeypatch.setattr( + "vllm.v1.engine.llm_engine.LLMEngine.from_engine_args", + lambda **kw: fake_engine, + raising=False, + ) + + class FakeModelClass: + pass + + monkeypatch.setattr( + "vllm.model_executor.model_loader.utils.get_model_architecture", + lambda model_config: (FakeModelClass, "test_arch"), + raising=False, + ) + monkeypatch.setattr( + "vllm.model_executor.model_loader.utils._get_model_architecture", + lambda model_config: (FakeModelClass, "test_arch"), + raising=False, + ) + monkeypatch.setattr( + "vllm.model_executor.models.adapters.try_create_mm_pooling_model_cls", + lambda model_cls: model_cls, + raising=False, + ) + monkeypatch.setattr( + "vllm.multimodal.cache._enable_processor_cache", + lambda model_config, mm_registry: False, + raising=False, + ) + monkeypatch.setattr( + "vllm.plugins.io_processors.get_io_processor", + lambda vllm_config, io_processor_plugin: None, + raising=False, + ) + + +def _setup_multiprocessing_mocks(monkeypatch): + import multiprocessing as mp + + fake_process_instance = _ns( + start=lambda: None, + join=lambda timeout=None: None, + is_alive=lambda: False, + terminate=lambda: None, + ) + + def fake_process_class(*args, **kwargs): + return fake_process_instance + + fake_ctx = _ns() + fake_ctx.Queue = lambda maxsize=0: _FakeQueue(maxsize=maxsize) + fake_ctx.Process = fake_process_class + + monkeypatch.setattr(mp, "get_context", lambda method: fake_ctx, raising=False) + monkeypatch.setattr(mp, "Process", fake_process_class, raising=False) + + +def _setup_ipc_mocks(monkeypatch): + # These IPC helpers existed in the old architecture; no-op in new arch. + pass + + +def _setup_log_mocks(monkeypatch): + class _FakeOrchestratorAggregator: + def __init__(self, num_stages, enable_stats, wall_start_ts, final_stage_id_for_e2e=None): + self.num_stages = num_stages + self.enable_stats = enable_stats + self.stage_first_ts = [None] * num_stages + self.stage_last_ts = [None] * num_stages + self.stage_total_tokens = [0] * num_stages + self.accumulated_gen_time_ms = {} + self.e2e_done = set() + self.e2e_count = 0 + self.e2e_total_ms = 0.0 + + def on_stage_metrics(self, stage_id, req_id, metrics, final_output_type=None): + pass + + def on_finalize_request(self, stage_id, req_id, start_ts): + self.e2e_done.add(req_id) + + def on_forward(self, from_stage, to_stage, req_id, size_bytes, tx_ms, use_shm): + pass + + def accumulate_diffusion_metrics(self, stage_type, req_id, engine_outputs): + pass + + def record_audio_generated_frames(self, output, stage_id, req_id): + pass + + def stage_postprocess_timer(self, stage_id, req_id): + from contextlib import contextmanager + + @contextmanager + def _noop(): + yield + + return _noop() + + def build_and_log_summary(self): + return "Fake summary" + + monkeypatch.setattr( + "vllm_omni.entrypoints.omni.OrchestratorAggregator", + _FakeOrchestratorAggregator, + raising=False, + ) + + +def _clear_modules(): + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + +@pytest.fixture(autouse=True) +def mock_get_config(monkeypatch): + """Auto-mock get_config and related model loading functions.""" + import sys + + fake_tokenizer = _ns() + fake_tokenizer.encode = lambda *args, **kwargs: [1, 2, 3] + fake_tokenizer.decode = lambda *args, **kwargs: "test" + + def _mock_init_tokenizer_from_configs(model_config=None, **kwargs): + return fake_tokenizer + + monkeypatch.setattr( + "vllm.transformers_utils.tokenizer.init_tokenizer_from_configs", + _mock_init_tokenizer_from_configs, + raising=False, + ) + tokenizer_module_path = "vllm.transformers_utils.tokenizer" + if tokenizer_module_path in sys.modules: + setattr(sys.modules[tokenizer_module_path], "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs) + + def _mock_length_from_prompt_token_ids_or_embeds(prompt_token_ids=None, prompt_embeds=None): + if prompt_token_ids is not None: + if isinstance(prompt_token_ids, list): + return len(prompt_token_ids) + return 10 + + monkeypatch.setattr( + "vllm.utils.length_from_prompt_token_ids_or_embeds", _mock_length_from_prompt_token_ids_or_embeds, raising=False + ) + monkeypatch.setattr( + "vllm_omni.engine.input_processor.length_from_prompt_token_ids_or_embeds", + _mock_length_from_prompt_token_ids_or_embeds, + raising=False, + ) + + processor_module_path = "vllm_omni.engine.input_processor" + if processor_module_path in sys.modules: + setattr( + sys.modules[processor_module_path], + "length_from_prompt_token_ids_or_embeds", + _mock_length_from_prompt_token_ids_or_embeds, + ) + + monkeypatch.setattr( + "vllm_omni.entrypoints.async_omni.init_tokenizer_from_configs", _mock_init_tokenizer_from_configs, raising=False + ) + async_omni_path = "vllm_omni.entrypoints.async_omni" + if async_omni_path in sys.modules: + setattr(sys.modules[async_omni_path], "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs) + + fake_hf_config = _ns() + fake_hf_config.model_type = "qwen2_5_omni" + + monkeypatch.setattr( + "vllm.transformers_utils.config.get_config", lambda model, **kwargs: fake_hf_config, raising=False + ) + monkeypatch.setattr("vllm_omni.entrypoints.utils.get_config", lambda model, **kwargs: fake_hf_config, raising=False) + + def _mock_cached_file(path_or_repo_id, *args, **kwargs): + import os + import tempfile + + fake_config_file = os.path.join(tempfile.gettempdir(), "fake_config.json") + if not os.path.exists(fake_config_file): + with open(fake_config_file, "w") as f: + f.write('{"model_type": "qwen2_5_omni"}') + return fake_config_file + + monkeypatch.setattr("transformers.utils.hub.cached_file", _mock_cached_file, raising=False) + monkeypatch.setattr( + "transformers.utils.hub.cached_files", + lambda path_or_repo_id, filenames, **kwargs: ( + [_mock_cached_file(path_or_repo_id, filenames[0])] if filenames else None + ), + raising=False, + ) + + +# --------------------------------------------------------------------------- +# Helper to build an Omni instance with PD stage configs +# --------------------------------------------------------------------------- + + +def _make_pd_omni(monkeypatch, stage_configs, *, extra_setup=None): + """Create a lightweight PDDisaggregationMixin instance for unit tests. + + Bypasses the full OmniBase / AsyncOmniEngine init chain so tests run + without GPU. Returns an object that has all PDDisaggregationMixin + methods and state (``_pd_separation_pair``, ``_pd_kv_params_by_req``, + etc.) initialised from *stage_configs*. + + Tests that need the full ``Omni.generate()`` loop (old stage_list / queue + infrastructure) are marked ``xfail`` and not covered here. + """ + configs = [_FakeStageConfig(c) for c in stage_configs] + + class _LightweightOmni(PDDisaggregationMixin): + """Minimal shim: exposes stage_configs so PDDisaggregationMixin works.""" + + def __init__(self): + self._name = "Omni" + self._stage_configs = configs + self._init_pd_state() + + @property + def stage_configs(self): + return self._stage_configs + + if extra_setup: + import vllm_omni.entrypoints.omni as omni_module + + extra_setup(monkeypatch, omni_module) + + return _LightweightOmni() + + +# --------------------------------------------------------------------------- +# Stage config templates +# --------------------------------------------------------------------------- + + +def _prefill_stage_cfg(stage_id=0, **overrides): + cfg = { + "stage_id": stage_id, + "engine_args": { + "model_stage": "thinker", + "kv_transfer_config": { + "kv_connector": "MooncakeConnector", + "kv_role": "kv_producer", + "kv_rank": 0, + "kv_parallel_size": 2, + "kv_connector_extra_config": {"mooncake_bootstrap_port": 25201}, + }, + }, + "is_prefill_only": True, + "final_output": False, + "is_comprehension": True, + } + cfg.update(overrides) + return cfg + + +def _decode_stage_cfg(stage_id=1, engine_input_source=None, **overrides): + cfg = { + "stage_id": stage_id, + "engine_args": { + "model_stage": "thinker", + "kv_transfer_config": { + "kv_connector": "MooncakeConnector", + "kv_role": "kv_consumer", + "kv_rank": 1, + "kv_parallel_size": 2, + "kv_connector_extra_config": {"mooncake_bootstrap_port": 25202}, + }, + }, + "is_decode_only": True, + "engine_input_source": engine_input_source if engine_input_source is not None else [0], + "final_output": True, + "final_output_type": "text", + "is_comprehension": True, + } + cfg.update(overrides) + return cfg + + +def _talker_stage_cfg(stage_id=2, engine_input_source=None, **overrides): + cfg = { + "stage_id": stage_id, + "engine_args": {"model_stage": "talker"}, + "engine_input_source": engine_input_source if engine_input_source is not None else [1], + "final_output": False, + } + cfg.update(overrides) + return cfg + + +def _code2wav_stage_cfg(stage_id=3, engine_input_source=None, **overrides): + cfg = { + "stage_id": stage_id, + "engine_args": {"model_stage": "code2wav"}, + "engine_input_source": engine_input_source if engine_input_source is not None else [2], + "final_output": True, + "final_output_type": "audio", + } + cfg.update(overrides) + return cfg + + +# =================================================================== +# Tests: PD pair detection +# =================================================================== + + +class TestDetectPDSeparation: + """Tests for Omni._detect_pd_separation().""" + + def test_detects_pd_pair(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(stage_id=0), + _decode_stage_cfg(stage_id=1, engine_input_source=[0]), + ], + ) + assert omni._pd_separation_pair == (0, 1) + + def test_no_pd_pair_without_flags(self, monkeypatch): + """Normal (non-PD) pipeline has no PD pair.""" + omni = _make_pd_omni( + monkeypatch, + [ + { + "stage_id": 0, + "engine_args": {"model_stage": "thinker"}, + "final_output": True, + "final_output_type": "text", + }, + { + "stage_id": 1, + "engine_args": {"model_stage": "talker"}, + "engine_input_source": [0], + "final_output": True, + "final_output_type": "audio", + }, + ], + ) + assert omni._pd_separation_pair is None + + def test_detects_pd_pair_in_4_stage_pipeline(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(stage_id=0), + _decode_stage_cfg(stage_id=1, engine_input_source=[0]), + _talker_stage_cfg(stage_id=2, engine_input_source=[1]), + _code2wav_stage_cfg(stage_id=3, engine_input_source=[2]), + ], + ) + assert omni._pd_separation_pair == (0, 1) + + def test_pd_pair_uses_stage_id_for_input_source(self, monkeypatch): + """engine_input_source references stage_id, not list index.""" + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(stage_id=10), + _decode_stage_cfg(stage_id=20, engine_input_source=[10]), + ], + ) + assert omni._pd_separation_pair == (0, 1) + + +# =================================================================== +# Tests: PD config validation +# =================================================================== + + +class TestValidatePDConfig: + """Tests for Omni._validate_pd_separation_config().""" + + def test_valid_config_passes(self, monkeypatch): + """Valid PD config should not raise.""" + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + # If we got here without error, validation passed + assert omni._pd_separation_pair == (0, 1) + + def test_mismatched_connector_raises(self, monkeypatch): + """Different kv_connector types should raise ValueError.""" + decode_cfg = _decode_stage_cfg(engine_input_source=[0]) + decode_cfg["engine_args"]["kv_transfer_config"]["kv_connector"] = "NixlConnector" + + with pytest.raises(ValueError, match="connector mismatch"): + _make_pd_omni(monkeypatch, [_prefill_stage_cfg(), decode_cfg]) + + def test_wrong_prefill_role_raises(self, monkeypatch): + """Prefill with kv_consumer role should raise.""" + prefill_cfg = _prefill_stage_cfg() + prefill_cfg["engine_args"]["kv_transfer_config"]["kv_role"] = "kv_consumer" + + with pytest.raises(ValueError, match="kv_role must be"): + _make_pd_omni(monkeypatch, [prefill_cfg, _decode_stage_cfg(engine_input_source=[0])]) + + def test_wrong_decode_role_raises(self, monkeypatch): + """Decode with kv_producer role should raise.""" + decode_cfg = _decode_stage_cfg(engine_input_source=[0]) + decode_cfg["engine_args"]["kv_transfer_config"]["kv_role"] = "kv_producer" + + with pytest.raises(ValueError, match="kv_role must be"): + _make_pd_omni(monkeypatch, [_prefill_stage_cfg(), decode_cfg]) + + def test_missing_kv_transfer_config_raises(self, monkeypatch): + """Missing kv_transfer_config should raise.""" + prefill_cfg = _prefill_stage_cfg() + del prefill_cfg["engine_args"]["kv_transfer_config"] + + with pytest.raises(ValueError, match="kv_transfer_config"): + _make_pd_omni(monkeypatch, [prefill_cfg, _decode_stage_cfg(engine_input_source=[0])]) + + def test_mismatched_buffer_device_raises(self, monkeypatch): + """Mismatched kv_buffer_device should raise.""" + prefill_cfg = _prefill_stage_cfg() + prefill_cfg["engine_args"]["kv_transfer_config"]["kv_buffer_device"] = "cuda" + decode_cfg = _decode_stage_cfg(engine_input_source=[0]) + decode_cfg["engine_args"]["kv_transfer_config"]["kv_buffer_device"] = "cpu" + + with pytest.raises(ValueError, match="kv_buffer_device mismatch"): + _make_pd_omni(monkeypatch, [prefill_cfg, decode_cfg]) + + +# =================================================================== +# Tests: Connector info extraction +# =================================================================== + + +class TestGetPDConnectorInfo: + """Tests for Omni._get_pd_connector_info().""" + + def test_extracts_bootstrap_addr_for_mooncake(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + info = omni._pd_connector_info + assert "prefill_bootstrap_addr" in info + assert info["prefill_bootstrap_addr"] == "127.0.0.1:25201" + + def test_none_for_non_pd_pipeline(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + {"stage_id": 0, "engine_args": {}, "final_output": True, "final_output_type": "text"}, + ], + ) + assert omni._pd_connector_info is None + + +# =================================================================== +# Tests: Prefill sampling params preparation +# =================================================================== + + +class TestPreparePrefillSamplingParams: + """Tests for Omni._prepare_prefill_sampling_params().""" + + def test_sets_max_tokens_to_1(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048) + result = omni._prepare_prefill_sampling_params("req-1", sp) + + assert result.max_tokens == 1 + assert result is not sp # should be cloned + + def test_injects_kv_transfer_params(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048) + result = omni._prepare_prefill_sampling_params("req-1", sp) + + kv_params = result.extra_args["kv_transfer_params"] + assert kv_params["do_remote_decode"] is True + assert kv_params["do_remote_prefill"] is False + assert kv_params["transfer_id"] == "xfer-req-1" + + def test_preserves_existing_extra_args(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048, extra_args={"custom_key": "value"}) + result = omni._prepare_prefill_sampling_params("req-1", sp) + + assert result.extra_args["custom_key"] == "value" + assert "kv_transfer_params" in result.extra_args + + def test_does_not_mutate_original(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048) + _ = omni._prepare_prefill_sampling_params("req-1", sp) + + assert sp.max_tokens == 2048 + assert sp.extra_args is None + + +# =================================================================== +# Tests: Sampling params auto-duplication for PD split +# =================================================================== + + +@pytest.mark.xfail(reason="Requires migration to v1908 Orchestrator architecture (no stage_list / OmniStage)") +class TestSamplingParamsAutoDuplication: + """When user provides N-1 sampling params (for logical stages), the + orchestrator should auto-duplicate the thinker params for the decode stage. + """ + + def test_auto_duplicates_for_4_stage_pipeline(self, monkeypatch): + """User provides 3 params for 4 physical stages -> auto-insert decode params.""" + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000001") + + def _extra_setup(mp, omni_module): + mp.setattr(uuid, "uuid4", lambda: test_uuid) + mp.setattr(omni_module, "uuid", uuid) + + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(stage_id=0), + _decode_stage_cfg(stage_id=1, engine_input_source=[0]), + _talker_stage_cfg(stage_id=2, engine_input_source=[1]), + _code2wav_stage_cfg(stage_id=3, engine_input_source=[2]), + ], + extra_setup=_extra_setup, + ) + + assert omni._pd_separation_pair == (0, 1) + assert len(omni.stage_list) == 4 + + # Simulate outputs for all stages + expected_rid = f"0_{test_uuid}" + for i in range(4): + omni.stage_list[i]._out_q.put_nowait( + { + "request_id": expected_rid, + "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1, 2])])], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + + # Provide 3 params (one less than 4 stages) - should auto-duplicate + sp_thinker = SamplingParams(temperature=0.4, max_tokens=2048) + sp_talker = SamplingParams(temperature=0.9, max_tokens=4096) + sp_code2wav = SamplingParams(temperature=0.0, max_tokens=65536) + + # This should NOT raise ValueError about param count mismatch + outputs = omni.generate( + prompts=["hello"], + sampling_params_list=[sp_thinker, sp_talker, sp_code2wav], + ) + assert isinstance(outputs, list) + + +# =================================================================== +# Tests: KV transfer params normalization +# =================================================================== + + +class TestNormalizeKVTransferParams: + def test_dict_passthrough(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + d = {"transfer_id": "test", "do_remote_decode": True} + assert omni._normalize_kv_transfer_params(d) is d + + def test_none_returns_none(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + assert omni._normalize_kv_transfer_params(None) is None + + def test_dataclass_to_dict(self, monkeypatch): + from dataclasses import dataclass + + @dataclass + class FakeKVParams: + transfer_id: str = "test" + do_remote_decode: bool = True + + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + result = omni._normalize_kv_transfer_params(FakeKVParams()) + assert isinstance(result, dict) + assert result["transfer_id"] == "test" + + +# =================================================================== +# Tests: _kv_cfg_to_dict +# =================================================================== + + +class TestKvCfgToDict: + def test_dict_passthrough(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + d = {"kv_connector": "MooncakeConnector"} + assert omni._kv_cfg_to_dict(d) is d + + def test_none_returns_empty(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + assert omni._kv_cfg_to_dict(None) == {} + + def test_dataclass_converted(self, monkeypatch): + from dataclasses import dataclass + + @dataclass + class FakeCfg: + kv_connector: str = "TestConnector" + kv_role: str = "kv_producer" + + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + result = omni._kv_cfg_to_dict(FakeCfg()) + assert result["kv_connector"] == "TestConnector" + assert result["kv_role"] == "kv_producer" + + +# =================================================================== +# Tests: PD routing in scheduling loop +# =================================================================== + + +@pytest.mark.xfail(reason="Requires migration to v1908 Orchestrator architecture (no stage_list / OmniStage)") +class TestPDRouting: + """Test that the scheduling loop correctly routes requests from + prefill to decode stage with proper kv_transfer_params. + """ + + def test_prefill_stage_receives_max_tokens_1(self, monkeypatch): + """Stage 0 (prefill) should receive max_tokens=1.""" + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000002") + + def _extra_setup(mp, omni_module): + mp.setattr(uuid, "uuid4", lambda: test_uuid) + mp.setattr(omni_module, "uuid", uuid) + + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(stage_id=0), + _decode_stage_cfg(stage_id=1, engine_input_source=[0]), + ], + extra_setup=_extra_setup, + ) + + expected_rid = f"0_{test_uuid}" + + # Put stage outputs in both queues + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_rid, + "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1])])], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_rid, + "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1, 2, 3])])], + "metrics": {"num_tokens_out": 3, "stage_gen_time_ms": 50.0}, + } + ) + + sp_list = [SamplingParams(max_tokens=2048), SamplingParams(max_tokens=2048)] + omni.generate(prompts=["hello"], sampling_params_list=sp_list) + + # Check what was submitted to stage 0's input queue + # (skip the stage_ready message first) + task = omni.stage_list[0]._in_q.get_nowait() + assert task["sampling_params"].max_tokens == 1 + kv_params = task["sampling_params"].extra_args["kv_transfer_params"] + assert kv_params["do_remote_decode"] is True + assert kv_params["do_remote_prefill"] is False + assert kv_params["transfer_id"] == f"xfer-{expected_rid}" + + def test_decode_stage_receives_original_prompt(self, monkeypatch): + """Decode stage should get the original prompt (not processed outputs).""" + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000003") + + def _extra_setup(mp, omni_module): + mp.setattr(uuid, "uuid4", lambda: test_uuid) + mp.setattr(omni_module, "uuid", uuid) + + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(stage_id=0), + _decode_stage_cfg(stage_id=1, engine_input_source=[0]), + ], + extra_setup=_extra_setup, + ) + + expected_rid = f"0_{test_uuid}" + original_prompt = "test prompt for PD" + + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_rid, + "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1])])], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_rid, + "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1, 2, 3])])], + "metrics": {"num_tokens_out": 3, "stage_gen_time_ms": 50.0}, + } + ) + + sp_list = [SamplingParams(max_tokens=2048), SamplingParams(max_tokens=2048)] + omni.generate(prompts=[original_prompt], sampling_params_list=sp_list) + + # Check what was forwarded to stage 1 (decode) + # The connector sends tasks to stage 1's input queue + task = omni.stage_list[1]._in_q.get_nowait() + # The engine_inputs should contain the original prompt + engine_inputs = task.get("engine_inputs") + # For PD routing, the original prompt is wrapped in a list + if isinstance(engine_inputs, list): + assert original_prompt in engine_inputs + else: + assert engine_inputs == original_prompt + + def test_decode_kv_params_have_correct_flags(self, monkeypatch): + """Decode stage kv_transfer_params should have correct role flags.""" + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000004") + + def _extra_setup(mp, omni_module): + mp.setattr(uuid, "uuid4", lambda: test_uuid) + mp.setattr(omni_module, "uuid", uuid) + + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(stage_id=0), + _decode_stage_cfg(stage_id=1, engine_input_source=[0]), + ], + extra_setup=_extra_setup, + ) + + expected_rid = f"0_{test_uuid}" + + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_rid, + "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1])])], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_rid, + "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1, 2, 3])])], + "metrics": {"num_tokens_out": 3, "stage_gen_time_ms": 50.0}, + } + ) + + sp_list = [SamplingParams(max_tokens=2048), SamplingParams(max_tokens=2048)] + omni.generate(prompts=["hello"], sampling_params_list=sp_list) + + # Check decode task's kv_transfer_params + task = omni.stage_list[1]._in_q.get_nowait() + kv_params = task["sampling_params"].extra_args["kv_transfer_params"] + assert kv_params["do_remote_prefill"] is True + assert kv_params["do_remote_decode"] is False + assert kv_params["transfer_id"] == f"xfer-{expected_rid}" + assert kv_params["remote_bootstrap_addr"] == "127.0.0.1:25201" + + +# =================================================================== +# Tests: KV params cleanup +# =================================================================== + + +class TestKVParamsCleanup: + def test_drop_cleans_up(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + omni._pd_kv_params_by_req["req-1"] = {"transfer_id": "xfer-1"} + omni._drop_pd_kv_params("req-1") + assert "req-1" not in omni._pd_kv_params_by_req + + def test_drop_nonexistent_is_noop(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + omni._drop_pd_kv_params("nonexistent") # should not raise + + def test_pop_returns_stored_params(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + stored = {"transfer_id": "xfer-1", "extra_field": "value"} + omni._pd_kv_params_by_req["req-1"] = stored + + result = omni._pop_pd_kv_params("req-1") + assert result == stored + assert "req-1" not in omni._pd_kv_params_by_req + + def test_pop_uses_fallback_when_no_stored(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + fallback = {"transfer_id": "xfer-fallback"} + result = omni._pop_pd_kv_params("req-1", fallback=fallback) + assert result == fallback + + +# =================================================================== +# Tests: Config YAML loads without error +# =================================================================== + + +class TestPDYAMLConfig: + def test_pd_yaml_loads(self): + """The PD separation YAML config should load without errors.""" + import os + + yaml_path = os.path.join( + os.path.dirname(__file__), + "../../vllm_omni/model_executor/stage_configs/qwen3_omni_moe_pd_separation.yaml", + ) + yaml_path = os.path.abspath(yaml_path) + if not os.path.exists(yaml_path): + pytest.skip("PD separation YAML not found") + + from omegaconf import OmegaConf + + cfg = OmegaConf.load(yaml_path) + stages = cfg.stage_args + assert len(stages) == 4 + + # Prefill stage + assert stages[0].is_prefill_only is True + assert stages[0].final_output is False + assert stages[0].is_comprehension is True + + # Decode stage + assert stages[1].is_decode_only is True + assert stages[1].final_output is True + assert stages[1].final_output_type == "text" + assert stages[1].is_comprehension is True + assert 0 in stages[1].engine_input_source + + # KV transfer configs + assert stages[0].engine_args.kv_transfer_config.kv_role == "kv_producer" + assert stages[1].engine_args.kv_transfer_config.kv_role == "kv_consumer" + assert stages[0].engine_args.kv_transfer_config.kv_connector == "MooncakeConnector" + assert stages[1].engine_args.kv_transfer_config.kv_connector == "MooncakeConnector" + + +class TestPrefillStopNeutralization: + """Tests that _prepare_prefill_sampling_params neutralizes stop + conditions to ensure finish_reason='length'. + """ + + def test_clears_stop_strings(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048, stop=["", "STOP"]) + result = omni._prepare_prefill_sampling_params("req-1", sp) + assert result.stop == [] + + def test_clears_stop_token_ids(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048, stop_token_ids=[151643, 151644]) + result = omni._prepare_prefill_sampling_params("req-1", sp) + assert result.stop_token_ids == [] + + def test_clears_include_stop_str_in_output(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048, include_stop_str_in_output=True) + result = omni._prepare_prefill_sampling_params("req-1", sp) + assert result.include_stop_str_in_output is False + + def test_original_sp_unchanged(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048, stop=[""], stop_token_ids=[151643]) + _ = omni._prepare_prefill_sampling_params("req-1", sp) + assert sp.stop == [""] + assert sp.stop_token_ids == [151643] + + +# =================================================================== +# Tests: Failure mode & memory leak prevention +# =================================================================== +# NOTE: Full generate()-level failure mode tests are removed for now. +# The _run_generation error handler (line 1344-1350 in omni.py) calls +# _drop_pd_kv_params but does not increment completed_requests, causing +# the while-loop to hang. These tests need to be revisited once the +# production error-handling path is fixed to properly terminate on +# stage errors. + + +# =================================================================== +# Tests: TP size validation +# =================================================================== + + +class TestTPSizeValidation: + """Tests that _validate_pd_separation_config checks tensor_parallel_size.""" + + def test_matching_tp_passes(self, monkeypatch): + """Same TP size should not raise.""" + prefill_cfg = _prefill_stage_cfg() + prefill_cfg["engine_args"]["tensor_parallel_size"] = 2 + decode_cfg = _decode_stage_cfg(engine_input_source=[0]) + decode_cfg["engine_args"]["tensor_parallel_size"] = 2 + omni = _make_pd_omni(monkeypatch, [prefill_cfg, decode_cfg]) + assert omni._pd_separation_pair == (0, 1) + + def test_mismatched_tp_raises(self, monkeypatch): + """Different TP sizes should raise ValueError.""" + prefill_cfg = _prefill_stage_cfg() + prefill_cfg["engine_args"]["tensor_parallel_size"] = 2 + decode_cfg = _decode_stage_cfg(engine_input_source=[0]) + decode_cfg["engine_args"]["tensor_parallel_size"] = 4 + with pytest.raises(ValueError, match="tensor_parallel_size"): + _make_pd_omni(monkeypatch, [prefill_cfg, decode_cfg]) + + def test_default_tp_no_error(self, monkeypatch): + """Stages without explicit TP (defaults to 1) should pass.""" + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + assert omni._pd_separation_pair == (0, 1) diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 238bbdcdbd4..fe6527349c5 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -25,6 +25,7 @@ import janus import torch from omegaconf import OmegaConf +from vllm import envs as vllm_envs from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType from vllm.logger import init_logger @@ -79,6 +80,7 @@ setup_stage_devices, terminate_alive_proc, ) +from vllm_omni.entrypoints.pd_utils import PDDisaggregationMixin from vllm_omni.entrypoints.utils import ( inject_omni_kv_config, load_and_resolve_stage_configs, @@ -902,6 +904,7 @@ async def _run_orchestrator() -> None: self._initialize_janus_queues() self._initialize_stages(stage_init_timeout) + pd_config = self._detect_pd_config() orchestrator = Orchestrator( request_async_queue=self.request_queue.async_q, output_async_queue=self.output_queue.async_q, @@ -910,6 +913,7 @@ async def _run_orchestrator() -> None: stage_clients=self.stage_clients, output_processors=self.output_processors, stage_vllm_configs=self.stage_vllm_configs, + pd_config=pd_config, ) if not startup_future.done(): startup_future.set_result(asyncio.get_running_loop()) @@ -1131,6 +1135,48 @@ def _normalize_cache_config(cache_backend: str | None, cache_config: Any | None) cache_config = AsyncOmniEngine._get_default_cache_config(cache_backend) return cache_config + def _detect_pd_config(self) -> dict[str, Any] | None: + """Detect PD (Prefill-Decode) disaggregation config from stage_configs. + Returns a dict with 'pd_pair' and 'bootstrap_addr', or None. + """ + pd_pair = PDDisaggregationMixin.detect_pd_separation_from_stage_configs(self.stage_configs) + if pd_pair is None: + return None + prefill_idx, decode_idx = pd_pair + + # Extract bootstrap address from prefill stage engine_args + bootstrap_addr: str | None = None + try: + prefill_cfg = self.stage_configs[prefill_idx] + ea = getattr(prefill_cfg, "engine_args", None) + kv_cfg = getattr(ea, "kv_transfer_config", None) if ea is not None else None + if kv_cfg is not None: + port = vllm_envs.VLLM_MOONCAKE_BOOTSTRAP_PORT + kv_ip = getattr(kv_cfg, "kv_ip", None) or "127.0.0.1" + bootstrap_addr = f"http://{kv_ip}:{port}" + except Exception as exc: + logger.warning("[AsyncOmniEngine] Could not extract PD bootstrap address: %s", exc) + + logger.info( + "[AsyncOmniEngine] PD disaggregation detected: prefill=stage-%d, decode=stage-%d, bootstrap=%s", + prefill_idx, + decode_idx, + bootstrap_addr, + ) + prefill_engine_id: str | None = None + try: + prefill_client = self.stage_clients[prefill_idx] + kv_cfg = getattr(getattr(prefill_client, "vllm_config", None), "kv_transfer_config", None) + prefill_engine_id = getattr(kv_cfg, "engine_id", None) + except Exception as exc: + logger.warning("[AsyncOmniEngine] Could not extract prefill engine_id: %s", exc) + + return { + "pd_pair": (prefill_idx, decode_idx), + "bootstrap_addr": bootstrap_addr, + "prefill_engine_id": prefill_engine_id, + } + @staticmethod def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: """Create a default single-stage diffusion config from kwargs.""" diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index 1de4282fea2..8204d70e68a 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -42,6 +42,7 @@ def build_engine_core_request_from_tokens( params: SamplingParams | PoolingParams, arrival_time: float | None = None, model_config: ModelConfig | None = None, + mm_features: list | None = None, ) -> OmniEngineCoreRequest: """Build an OmniEngineCoreRequest directly from an OmniTokensPrompt. @@ -76,7 +77,7 @@ def build_engine_core_request_from_tokens( return OmniEngineCoreRequest( request_id=request_id, prompt_token_ids=prompt_token_ids, - mm_features=None, + mm_features=mm_features, sampling_params=sampling_params, pooling_params=pooling_params, arrival_time=arrival_time, @@ -104,6 +105,8 @@ class OrchestratorRequestState: # Metrics: timestamp when request was submitted to each stage stage_submit_ts: dict[int, float] = field(default_factory=dict) + mm_processor_kwargs: dict | None = None + mm_features: list | None = None class Orchestrator: @@ -123,6 +126,7 @@ def __init__( stage_vllm_configs: list[Any], *, async_chunk: bool = False, + pd_config: dict[str, Any] | None = None, ) -> None: self.request_async_queue = request_async_queue self.output_async_queue = output_async_queue @@ -135,6 +139,16 @@ def __init__( self.output_processors: list[Any] = output_processors self.stage_vllm_configs: list[Any] = stage_vllm_configs + # PD disaggregation state + self._pd_pair: tuple[int, int] | None = None + self._pd_bootstrap_addr: str | None = None + self._pd_prefill_engine_id: str | None = None + self._pd_kv_params: dict[str, Any] = {} + if pd_config is not None: + self._pd_pair = pd_config.get("pd_pair") + self._pd_bootstrap_addr = pd_config.get("bootstrap_addr") + self._pd_prefill_engine_id = pd_config.get("prefill_engine_id") + # Per-request state self.request_states: dict[str, OrchestratorRequestState] = {} @@ -359,6 +373,23 @@ async def _route_output( } ) + # PD disaggregation: extract KV transfer params from prefill stage output + if self._pd_pair is not None and finished and stage_id == self._pd_pair[0]: + kv_params = getattr(output, "kv_transfer_params", None) + if kv_params is not None: + self._pd_kv_params[req_id] = kv_params if isinstance(kv_params, dict) else dict(kv_params) + logger.debug( + "[Orchestrator][PD] stored kv_transfer_params for req=%s (keys=%s)", + req_id, + list(self._pd_kv_params[req_id].keys()), + ) + else: + logger.warning( + "[Orchestrator][PD] prefill stage output for req=%s has no kv_transfer_params; " + "KV transfer may fail. Ensure apply_mooncake_connector_patch() was called.", + req_id, + ) + if ( finished and stage_id < req_state.final_stage_id @@ -371,6 +402,8 @@ async def _route_output( await self._forward_to_next_stage(req_id, stage_id, output, req_state) if finished and stage_id == req_state.final_stage_id: + # PD: clean up any lingering KV params for this request + self._pd_kv_params.pop(req_id, None) self._cfg_tracker.cleanup_parent(req_id) self.request_states.pop(req_id, None) @@ -418,6 +451,51 @@ async def _handle_kv_ready_raw_outputs(self, stage_id: int, raw_outputs: EngineC else: await self._forward_to_next_stage(req_id, stage_id, raw_output, req_state) + def _build_pd_decode_params(self, req_id: str, sp: Any) -> Any: + """Build decode-side sampling params with KV transfer params for PD routing. + + Clones the sampling params and injects kv_transfer_params that tell the + decode engine where to pull the KV cache from (prefill engine's bootstrap addr). + """ + sp = sp.clone() + if sp.extra_args is None: + sp.extra_args = {} + + # Get KV params captured from the prefill output (must include remote_request_id). + kv_prefill_params = self._pd_kv_params.pop(req_id, None) + if not kv_prefill_params or "remote_request_id" not in kv_prefill_params: + raise RuntimeError( + f"[Orchestrator][PD] Missing prefill kv_transfer_params.remote_request_id for req={req_id}" + ) + + decode_kv_params: dict[str, Any] = { + "transfer_id": f"xfer-{req_id}", + } + + if self._pd_bootstrap_addr: + decode_kv_params["remote_bootstrap_addr"] = self._pd_bootstrap_addr + + if self._pd_prefill_engine_id: + decode_kv_params["remote_engine_id"] = self._pd_prefill_engine_id + + # Overlay params from prefill side (includes remote_request_id set by monkey patch). + decode_kv_params.update(kv_prefill_params) + + # Ensure these flags are set correctly after any overlay. + decode_kv_params["do_remote_prefill"] = True + decode_kv_params["do_remote_decode"] = False + if not decode_kv_params.get("transfer_id"): + decode_kv_params["transfer_id"] = f"xfer-{req_id}" + + sp.extra_args["kv_transfer_params"] = decode_kv_params + + logger.debug( + "[Orchestrator][PD] decode kv_transfer_params for req=%s: %s", + req_id, + decode_kv_params, + ) + return sp + def _build_stage_metrics( self, stage_id: int, @@ -540,6 +618,52 @@ async def _forward_to_next_stage( req_state.stage_submit_ts[next_stage_id] = _time.time() return + # PD disaggregation: prefill → decode routing uses original prompt + KV transfer params + if self._pd_pair is not None and (stage_id, next_stage_id) == self._pd_pair: + # Save prefill stage outputs so thinker2talker can merge embeddings later + self.stage_clients[stage_id].set_engine_outputs([output]) + + params = self._build_pd_decode_params(req_id, params) + + # Use the original user prompt for the decode stage (not processed embeddings) + original_prompt = req_state.prompt + raw_decode_inputs = [original_prompt] if not isinstance(original_prompt, list) else original_prompt + + decode_inputs: list[dict[str, Any]] = [] + for decode_input in raw_decode_inputs: + if isinstance(decode_input, dict): + decode_inputs.append(decode_input) + continue + prompt_token_ids = getattr(decode_input, "prompt_token_ids", None) + if prompt_token_ids is None: + raise TypeError( + "[Orchestrator][PD] decode input must be dict or have prompt_token_ids, " + f"got {type(decode_input).__name__} for req={req_id}" + ) + decode_inputs.append({"prompt_token_ids": list(prompt_token_ids)}) + + for decode_input in decode_inputs: + request = build_engine_core_request_from_tokens( + request_id=req_id, + prompt=decode_input, + params=params, + model_config=self.stage_vllm_configs[next_stage_id].model_config, + mm_features=req_state.mm_features, # Pass mm_features for M-RoPE + ) + request.external_req_id = request.request_id + + self.output_processors[next_stage_id].add_request( + request=request, + prompt=None, + parent_req=None, + request_index=0, + queue=None, + ) + await next_client.add_request_async(request) + + req_state.stage_submit_ts[next_stage_id] = _time.time() + return + self.stage_clients[stage_id].set_engine_outputs([output]) # Process inputs for next stage @@ -558,11 +682,16 @@ async def _forward_to_next_stage( # Build and submit requests for each input for next_input in next_inputs: + # Only AR thinker stages consume encoder mm_features; downstream + # (talker/code2wav/…) must not see them (avoids encoder-cache misses). + _ms = getattr(next_client, "model_stage", None) + _mm_features = req_state.mm_features if _ms == "thinker" else None request = build_engine_core_request_from_tokens( request_id=req_id, prompt=next_input, params=params, model_config=self.stage_vllm_configs[next_stage_id].model_config, + mm_features=_mm_features, ) # TODO: Here we directly use the req id to assign. @@ -644,6 +773,7 @@ async def _handle_add_request(self, msg: dict[str, Any]) -> None: prompt=original_prompt, sampling_params_list=sampling_params_list, final_stage_id=final_stage_id, + mm_features=getattr(prompt, "mm_features", None), # Save mm_features for PD ) req_state.stage_submit_ts[stage_id] = _time.time() self.request_states[request_id] = req_state diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py index badd799fc94..67b4dd16504 100644 --- a/vllm_omni/engine/output_processor.py +++ b/vllm_omni/engine/output_processor.py @@ -233,10 +233,9 @@ def _new_completion_output( # Reuse base text/logprobs logic, then annotate with pooling_result. base_output = super()._new_completion_output(token_ids, finish_reason, stop_reason, routed_experts) try: + if not hasattr(base_output, "multimodal_output"): + setattr(base_output, "multimodal_output", {}) if self.mm_accumulated is not None: - # Attach accumulated multimodal dict on the completion output - if not hasattr(base_output, "multimodal_output"): - setattr(base_output, "multimodal_output", {}) mm_out = getattr(base_output, "multimodal_output") if isinstance(mm_out, dict): for k, v in self.mm_accumulated.items(): diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 129ef3c99d8..5823aa4ab04 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -78,7 +78,6 @@ def __init__(self, *args: Any, model: str = "", **kwargs: Any) -> None: self.final_output_task: asyncio.Task | None = None self.config_path = self.engine.config_path - self.stage_configs = self.engine.stage_configs self.tts_max_instructions_length = kwargs.get("tts_max_instructions_length", None) self.input_processor = self.engine.input_processor @@ -209,6 +208,13 @@ async def generate( # Start final output dispatcher on the first call to generate() self._final_output_handler() + # Expand sampling params for PD disaggregation (user may provide N-1 params) + if ( + sampling_params_list is not None + and isinstance(sampling_params_list, Sequence) + and not isinstance(sampling_params_list, (str, bytes)) + ): + sampling_params_list = self._maybe_expand_sampling_params(list(sampling_params_list)) sampling_params_list = self.resolve_sampling_params_list(sampling_params_list) # Track per-request metrics @@ -228,20 +234,27 @@ async def generate( req_state.metrics = metrics self.request_states[request_id] = req_state + # PD disaggregation: modify prefill-stage sampling params per request + req_sp_list = list(sampling_params_list) + pd_pair = self._get_pd_separation_pair() + if pd_pair is not None: + p_id = pd_pair[0] + req_sp_list[p_id] = self._prepare_prefill_sampling_params(request_id, req_sp_list[p_id]) + # Add request(s) to stage 0. For streaming inputs, submit # chunks incrementally through streaming_update. if isinstance(prompt, AsyncGenerator): input_stream_task = await self._add_streaming_input_request( request_id=request_id, input_stream=prompt, - sampling_params_list=sampling_params_list, + sampling_params_list=req_sp_list, final_stage_id=final_stage_id_for_e2e, ) else: await self.engine.add_request_async( request_id=request_id, prompt=prompt, - sampling_params_list=sampling_params_list, + sampling_params_list=req_sp_list, final_stage_id=final_stage_id_for_e2e, ) submit_ts = time.time() diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index a3bfe98ce2c..8ef7e2ee5b7 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -66,6 +66,13 @@ def generate( py_generator: bool = False, use_tqdm: bool | Callable[..., tqdm] = True, ) -> Generator[OmniRequestOutput, None, None] | list[OmniRequestOutput]: + # Expand sampling params for PD disaggregation (user may provide N-1 params) + if ( + sampling_params_list is not None + and isinstance(sampling_params_list, Sequence) + and not isinstance(sampling_params_list, (str, bytes)) + ): + sampling_params_list = self._maybe_expand_sampling_params(list(sampling_params_list)) sampling_params_list = self.resolve_sampling_params_list(sampling_params_list) try: if py_generator: @@ -125,10 +132,17 @@ def _run_generation( req_state.metrics = metrics self.request_states[req_id] = req_state + # PD disaggregation: modify stage-0 (prefill) sampling params per request + req_sp_list = list(sampling_params_list) + pd_pair = self._get_pd_separation_pair() + if pd_pair is not None: + p_id = pd_pair[0] + req_sp_list[p_id] = self._prepare_prefill_sampling_params(req_id, req_sp_list[p_id]) + self.engine.add_request( request_id=req_id, prompt=prompt, - sampling_params_list=sampling_params_list, + sampling_params_list=req_sp_list, final_stage_id=final_stage_id, ) submit_ts = time.time() diff --git a/vllm_omni/entrypoints/omni_base.py b/vllm_omni/entrypoints/omni_base.py index 1a7ffc4a504..82c64892478 100644 --- a/vllm_omni/entrypoints/omni_base.py +++ b/vllm_omni/entrypoints/omni_base.py @@ -14,6 +14,7 @@ from vllm_omni.engine.async_omni_engine import AsyncOmniEngine from vllm_omni.entrypoints.client_request_state import ClientRequestState +from vllm_omni.entrypoints.pd_utils import PDDisaggregationMixin from vllm_omni.entrypoints.utils import get_final_stage_id_for_e2e from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific @@ -65,7 +66,7 @@ def omni_snapshot_download(model_id: str) -> str: OutputMessageHandleResult = tuple[Literal[True], None, None, None] | tuple[Literal[False], str, int, ClientRequestState] -class OmniBase: +class OmniBase(PDDisaggregationMixin): """Shared runtime foundation for AsyncOmni and Omni.""" def __init__( @@ -84,6 +85,7 @@ def __init__( if "log_requests" in kwargs: raise TypeError("`log_requests` has been removed in Omni/AsyncOmni. Use `log_stats`.") model = omni_snapshot_download(model) + self._name = self.__class__.__name__ self.model = model self.log_stats = log_stats self.async_chunk = async_chunk @@ -125,10 +127,18 @@ def __init__( model, ) + # PD disaggregation state (detects if a prefill/decode stage pair is configured) + self._init_pd_state() + @property def num_stages(self) -> int: return self.engine.num_stages + @property + def stage_configs(self) -> list: + """Expose engine stage configs for PD disaggregation detection and validation.""" + return self.engine.stage_configs + @property def is_running(self) -> bool: return self.engine.is_alive() diff --git a/vllm_omni/entrypoints/pd_utils.py b/vllm_omni/entrypoints/pd_utils.py index 0e3d65f5537..413d5d6b448 100644 --- a/vllm_omni/entrypoints/pd_utils.py +++ b/vllm_omni/entrypoints/pd_utils.py @@ -23,9 +23,19 @@ class PDDisaggregationMixin: """Mixin supplying PD disaggregation helpers to OmniBase.""" + def _get_pd_separation_pair(self) -> tuple[int, int] | None: + """PD prefill/decode indices when ``_init_pd_state`` ran; else ``None``. + + Partial test doubles may skip ``OmniBase.__init__``; treat missing state as + no PD disaggregation instead of raising ``AttributeError``. + """ + return getattr(self, "_pd_separation_pair", None) + def _init_pd_state(self) -> None: """Initialise PD disaggregation state.""" - self._pd_separation_pair: tuple[int, int] | None = self._detect_pd_separation() + self._pd_separation_pair: tuple[int, int] | None = self.detect_pd_separation_from_stage_configs( + self.stage_configs + ) self._pd_connector_info: dict[str, Any] | None = None self._pd_kv_params_by_req: dict[str, dict[str, Any]] = {} self._pd_kv_params_lock = threading.Lock() @@ -40,11 +50,19 @@ def _init_pd_state(self) -> None: d_id, ) - def _detect_pd_separation(self) -> tuple[int, int] | None: - """Scan stage_list for a prefill/decode pair. Returns (p_id, d_id) or None.""" + @staticmethod + def detect_pd_separation_from_stage_configs(stage_configs: list[Any]) -> tuple[int, int] | None: + """Scan stage configs for a prefill/decode pair. + + Returns: + (prefill_idx, decode_idx) if one pair exists, None if not found. + + Raises: + ValueError: if multiple candidate PD pairs are found. + """ prefill_by_id: dict[int, int] = {} decode_indices: list[int] = [] - for i, stage in enumerate(self.stage_list): + for i, stage in enumerate(stage_configs): if getattr(stage, "is_prefill_only", False): prefill_by_id[i] = i sid = getattr(stage, "stage_id", i) @@ -55,7 +73,7 @@ def _detect_pd_separation(self) -> tuple[int, int] | None: pd_pairs: list[tuple[int, int]] = [] for j in decode_indices: - source_ids = getattr(self.stage_list[j], "engine_input_source", []) + source_ids = getattr(stage_configs[j], "engine_input_source", []) for src in source_ids: if src in prefill_by_id: pd_pairs.append((prefill_by_id[src], j)) @@ -107,10 +125,11 @@ def _normalize_kv_transfer_params(self, kv_params: Any) -> dict[str, Any] | None def _validate_pd_separation_config(self) -> None: """Validate PD stage configurations are consistent.""" - assert self._pd_separation_pair is not None - p_id, d_id = self._pd_separation_pair - p_stage = self.stage_list[p_id] - d_stage = self.stage_list[d_id] + pair = self._get_pd_separation_pair() + assert pair is not None + p_id, d_id = pair + p_stage = self.stage_configs[p_id] + d_stage = self.stage_configs[d_id] def _get_kv_cfg(stage: "OmniStage") -> dict[str, Any]: ea = stage.engine_args @@ -158,11 +177,12 @@ def _get_kv_cfg(stage: "OmniStage") -> dict[str, Any]: def _get_pd_connector_info(self) -> dict[str, Any] | None: """Extract prefill engine KV connector info.""" - if self._pd_separation_pair is None: + pair = self._get_pd_separation_pair() + if pair is None: return None - p_id, _ = self._pd_separation_pair - p_stage = self.stage_list[p_id] + p_id, _ = pair + p_stage = self.stage_configs[p_id] ea = p_stage.engine_args kv_cfg = getattr(ea, "kv_transfer_config", None) @@ -241,18 +261,17 @@ def _extract_kv_transfer_params(self, engine_outputs: Any) -> dict[str, Any] | N def _is_pd_routing(self, stage_id: int, next_stage_id: int) -> bool: """True when edge stage_id → next_stage_id is the prefill→decode boundary.""" - return self._pd_separation_pair is not None and self._pd_separation_pair == ( - stage_id, - next_stage_id, - ) + pair = self._get_pd_separation_pair() + return pair is not None and pair == (stage_id, next_stage_id) def _maybe_expand_sampling_params(self, sampling_params_list: list) -> list: """Auto-duplicate thinker SP for decode stage when user provides N-1 params.""" - if self._pd_separation_pair is None: + pair = self._get_pd_separation_pair() + if pair is None: return sampling_params_list - if len(sampling_params_list) != len(self.stage_list) - 1: + if len(sampling_params_list) != len(self.stage_configs) - 1: return sampling_params_list - p_id, d_id = self._pd_separation_pair + p_id, d_id = pair sp_list = list(sampling_params_list) sp_list.insert(d_id, sp_list[p_id]) return sp_list diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 7df69479734..cd8d147ca17 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -891,10 +891,11 @@ def _thinker_to_talker_prefill( Returns: (input_ids, input_embeds) for talker """ + target_len = thinker_result_ids.shape[-1] im_start_indexes = torch.cat( ( torch.nonzero(input_ids[0] == self.config.im_start_token_id).squeeze(), - torch.tensor([thinker_result_ids.shape[-1]], device=input_ids.device, dtype=input_ids.dtype), + torch.tensor([target_len], device=input_ids.device, dtype=input_ids.dtype), ), dim=-1, ) # Shape [n_starts + 1]; Take batch 0 since batched inference is not supported here. @@ -1029,8 +1030,35 @@ def talker_preprocess_decode( return last_talker_hidden, text_step, update_dict def _get_talker_user_parts(self, im_start_index, segment_end_index, multimodal_mask, thinker_hidden, thinker_embed): + clamped = min( + segment_end_index, + multimodal_mask.shape[0], + thinker_hidden.shape[0], + thinker_embed.shape[0], + ) + if clamped < segment_end_index: + logger.warning( + "_get_talker_user_parts: segment_end_index %d clamped to %d " + "(embed=%d, hidden=%d, mask=%d). " + "This usually means _merge_pd_embeddings failed to merge " + "prefill embeddings – check PD prefill_mm keys.", + segment_end_index, + clamped, + thinker_embed.shape[0], + thinker_hidden.shape[0], + multimodal_mask.shape[0], + ) + segment_end_index = clamped + seg_len = segment_end_index - im_start_index + if seg_len <= 0: + return torch.empty( + (0, self.config.talker_config.text_config.hidden_size), + device=thinker_hidden.device, + dtype=torch.bfloat16, + ) + user_talker_part = torch.empty( - (segment_end_index - im_start_index, self.config.talker_config.text_config.hidden_size), + (seg_len, self.config.talker_config.text_config.hidden_size), device=thinker_hidden.device, dtype=torch.bfloat16, ) diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index f4828fddaa5..c502041fe20 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -3,6 +3,7 @@ # Copyright 2025 The Qwen team. """Stage input processor for Qwen3 Omni MoE: Thinker → Talker transition.""" +import logging from typing import Any import torch @@ -18,6 +19,12 @@ extract_speaker_from_request, ) +logger = logging.getLogger(__name__) + +# Pooling output layer keys: "0" = word embedding, "24" = accept_hidden_layer +_EMBED_LAYER_KEY = "0" +_HIDDEN_LAYER_KEY = "24" + def _compute_talker_prompt_ids_length(info, device: torch.device | str = "cuda") -> int: im_start_token_id = 151644 @@ -84,6 +91,95 @@ def _validate_stage_inputs(stage_list, engine_input_source): return stage.engine_outputs +# ========================= +# PD disaggregation helpers +# ========================= + + +def _get_prefill_stage(stage_list: list[Any], source_stage_id: int) -> Any | None: + if source_stage_id <= 0: + return None + source_stage = stage_list[source_stage_id] + if not getattr(source_stage, "is_decode_only", False): + return None + prev_stage = stage_list[source_stage_id - 1] + if getattr(prev_stage, "is_prefill_only", False) and prev_stage.engine_outputs is not None: + return prev_stage + return None + + +def _merge_pd_embeddings( + decode_emb: torch.Tensor, + decode_hid: torch.Tensor, + prefill_mm: dict[str, Any], + device: torch.device, + expected_total: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Merge prefill prompt embeddings with decode generated embeddings. + + In PD mode the prefill engine processes the prompt and the decode engine + generates tokens starting from position 1. This function concatenates + them, removing the overlapping token(s): + + merged = prefill[:P] + decode[overlap:] + + where overlap = P + D - expected_total. + """ + try: + p_emb = prefill_mm[_EMBED_LAYER_KEY].detach().to(device=device, dtype=torch.float) + p_hid = prefill_mm[_HIDDEN_LAYER_KEY].detach().to(device=device, dtype=torch.float) + except (KeyError, AttributeError, TypeError) as exc: + available_keys = list(prefill_mm.keys()) if isinstance(prefill_mm, dict) else type(prefill_mm).__name__ + logger.error( + "_merge_pd_embeddings: failed to extract prefill embeddings (%s). " + "Expected keys %r and %r, got: %s. " + "Falling back to decode-only embeddings – talker user-segment will be degraded.", + exc, + _EMBED_LAYER_KEY, + _HIDDEN_LAYER_KEY, + available_keys, + ) + return decode_emb, decode_hid + + if p_emb.shape[0] == 0 or decode_emb.shape[0] == 0: + return decode_emb, decode_hid + + raw_total = p_emb.shape[0] + decode_emb.shape[0] + overlap = max(0, raw_total - expected_total) if expected_total is not None else 0 + + merged_emb = torch.cat([p_emb, decode_emb[overlap:]], dim=0) + merged_hid = torch.cat([p_hid, decode_hid[overlap:]], dim=0) + return merged_emb, merged_hid + + +def _get_prefill_multimodal_output(prefill_stage: Any, output_index: int) -> dict[str, Any] | None: + """Return multimodal_output dict from the PD prefill stage for a given batch index.""" + try: + prefill_eos = prefill_stage.engine_outputs + prefill_eo = prefill_eos[min(output_index, len(prefill_eos) - 1)] + return prefill_eo.outputs[0].multimodal_output + except Exception: + return None + + +def _resolve_tts_token_embedding( + key: str, + *, + thinker_mm: dict[str, Any], + prefill_mm: dict[str, Any] | None, + device: torch.device, +) -> torch.Tensor | None: + """Return TTS BOS/EOS/PAD embedding tensors for the talker projection path. + + Values are taken from the current thinker (decode) ``multimodal_output``; in + PD mode, missing keys may be filled from the paired prefill stage output. + """ + val = thinker_mm.get(key) + if val is None and prefill_mm is not None: + val = prefill_mm.get(key) + return val.detach().to(device=device, dtype=torch.float) if val is not None else None + + # ========================= # Thinker -> Talker # ========================= @@ -111,8 +207,8 @@ def thinker2talker_async_chunk( all_token_ids = _ensure_list(all_token_ids) prompt_token_ids = _ensure_list(prompt_token_ids) talker_additional_info = { - "thinker_prefill_embeddings": pooling_output.get("0").detach().cpu(), - "thinker_hidden_states": pooling_output.get("24").detach().cpu(), + "thinker_prefill_embeddings": pooling_output.get(_EMBED_LAYER_KEY).detach().cpu(), + "thinker_hidden_states": pooling_output.get(_HIDDEN_LAYER_KEY).detach().cpu(), "thinker_sequences": all_token_ids, "thinker_input_ids": prompt_token_ids, # Provide thinker-side TTS token embeddings for talker projection @@ -161,7 +257,7 @@ def thinker2talker_async_chunk( if output_token_ids: talker_additional_info["override_keys"] = ["thinker_decode_embeddings", "thinker_output_token_ids"] - talker_additional_info["thinker_decode_embeddings"] = pooling_output.get("0").detach().cpu() + talker_additional_info["thinker_decode_embeddings"] = pooling_output.get(_EMBED_LAYER_KEY).detach().cpu() talker_additional_info["thinker_output_token_ids"] = output_token_ids else: # When prefilling a chunked thinker, thinker_hidden_states needs to be updated. @@ -185,6 +281,9 @@ def thinker2talker( 2. Split hidden states into: prompt embeddings + generated embeddings 3. Package for talker with additional information + In PD disaggregation mode, merges prefill-stage prompt embeddings with + decode-stage generated embeddings before handing off to the talker. + Args: stage_list: List of stage objects engine_input_source: Source stage IDs (typically [0] for thinker) @@ -199,21 +298,49 @@ def thinker2talker( device = torch.device(current_platform.device_type) + # PD disaggregation: look up the preceding prefill stage (if any) + source_stage_id = engine_input_source[0] + prefill_stage = _get_prefill_stage(stage_list, source_stage_id) + # Process each thinker output for i, thinker_output in enumerate(thinker_outputs): output = thinker_output.outputs[0] + thinker_mm = output.multimodal_output + # Full thinker embedding sequence for the talker: single thinker engine in the + # non-PD path; after optional merge with prefill-side tensors in PD mode. + thinker_emb = thinker_mm[_EMBED_LAYER_KEY].detach().to(device=device, dtype=torch.float) + thinker_hid = thinker_mm[_HIDDEN_LAYER_KEY].detach().to(device=device, dtype=torch.float) + + prefill_mm: dict[str, Any] | None = None + if prefill_stage is not None: + prefill_mm = _get_prefill_multimodal_output(prefill_stage, i) + + if prefill_mm is not None: + expected_total = len(thinker_output.prompt_token_ids) + len(output.token_ids) + try: + thinker_emb, thinker_hid = _merge_pd_embeddings( + thinker_emb, thinker_hid, prefill_mm, device, expected_total=expected_total + ) + except Exception as exc: + logger.warning("[PD] Could not merge prefill embeddings: %s", exc) info = { - "thinker_prefill_embeddings": output.multimodal_output["0"].detach().to(device=device, dtype=torch.float), - "thinker_hidden_states": output.multimodal_output["24"].detach().to(device=device, dtype=torch.float), + "thinker_prefill_embeddings": thinker_emb, + "thinker_hidden_states": thinker_hid, "thinker_sequences": ( thinker_output.prompt_token_ids + output.token_ids ), # the thinker_sequences is the whole ids "thinker_input_ids": thinker_output.prompt_token_ids, # Provide thinker-side TTS token embeddings for talker projection - "tts_bos_embed": output.multimodal_output["tts_bos_embed"].detach().to(device=device, dtype=torch.float), - "tts_eos_embed": output.multimodal_output["tts_eos_embed"].detach().to(device=device, dtype=torch.float), - "tts_pad_embed": output.multimodal_output["tts_pad_embed"].detach().to(device=device, dtype=torch.float), + "tts_bos_embed": _resolve_tts_token_embedding( + "tts_bos_embed", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device + ), + "tts_eos_embed": _resolve_tts_token_embedding( + "tts_eos_embed", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device + ), + "tts_pad_embed": _resolve_tts_token_embedding( + "tts_pad_embed", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device + ), } speaker = extract_speaker_from_prompt(prompt, index=i) if speaker is not None: From c0ccbb872f018851f2c5a6c168e8175eeb10704e Mon Sep 17 00:00:00 2001 From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Date: Fri, 17 Apr 2026 19:10:02 +0800 Subject: [PATCH 02/38] [Model] Add Ming-flash-omni-2.0 Thinker Stage (#1822) Signed-off-by: yuanheng Signed-off-by: Yuanheng Zhao --- .../ming_flash_omni/README.md | 76 ++ .../ming_flash_omni/end2end.py | 485 ++++++++++ .../online_serving/ming_flash_omni/README.md | 204 ++++ .../run_curl_multimodal_generation.sh | 145 +++ tests/conftest.py | 4 + .../offline_inference/test_ming_flash_omni.py | 142 +++ .../online_serving/test_ming_flash_omni.py | 247 +++++ .../bailingmm_moe_v2_lite_ci.yaml | 35 + .../models/ming_flash_omni/__init__.py | 18 + .../models/ming_flash_omni/audio_encoder.py | 246 +++++ .../models/ming_flash_omni/ming_flash_omni.py | 223 +++++ .../ming_flash_omni_thinker.py | 893 +++++++++++++++++ .../modeling_bailing_moe_v2.py | 896 ++++++++++++++++++ .../models/ming_flash_omni/projectors.py | 184 ++++ .../models/ming_flash_omni/vision_encoder.py | 125 +++ .../qwen3_tts/tokenizer_25hz/vq/speech_vq.py | 3 +- .../tokenizer_25hz/vq/whisper_encoder.py | 25 +- vllm_omni/model_executor/models/registry.py | 17 + .../model_executor/models/whisper_utils.py | 39 + .../stage_configs/bailingmm_moe_v2_lite.yaml | 46 + .../transformers_utils/configs/__init__.py | 11 + .../configs/ming_flash_omni.py | 302 ++++++ .../transformers_utils/processors/__init__.py | 12 + .../transformers_utils/processors/ming.py | 430 +++++++++ 24 files changed, 4783 insertions(+), 25 deletions(-) create mode 100644 examples/offline_inference/ming_flash_omni/README.md create mode 100644 examples/offline_inference/ming_flash_omni/end2end.py create mode 100644 examples/online_serving/ming_flash_omni/README.md create mode 100755 examples/online_serving/ming_flash_omni/run_curl_multimodal_generation.sh create mode 100644 tests/e2e/offline_inference/test_ming_flash_omni.py create mode 100644 tests/e2e/online_serving/test_ming_flash_omni.py create mode 100644 tests/e2e/stage_configs/bailingmm_moe_v2_lite_ci.yaml create mode 100644 vllm_omni/model_executor/models/ming_flash_omni/__init__.py create mode 100644 vllm_omni/model_executor/models/ming_flash_omni/audio_encoder.py create mode 100644 vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py create mode 100644 vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_thinker.py create mode 100644 vllm_omni/model_executor/models/ming_flash_omni/modeling_bailing_moe_v2.py create mode 100644 vllm_omni/model_executor/models/ming_flash_omni/projectors.py create mode 100644 vllm_omni/model_executor/models/ming_flash_omni/vision_encoder.py create mode 100644 vllm_omni/model_executor/models/whisper_utils.py create mode 100644 vllm_omni/model_executor/stage_configs/bailingmm_moe_v2_lite.yaml create mode 100644 vllm_omni/transformers_utils/configs/ming_flash_omni.py create mode 100644 vllm_omni/transformers_utils/processors/__init__.py create mode 100644 vllm_omni/transformers_utils/processors/ming.py diff --git a/examples/offline_inference/ming_flash_omni/README.md b/examples/offline_inference/ming_flash_omni/README.md new file mode 100644 index 00000000000..7414163fc01 --- /dev/null +++ b/examples/offline_inference/ming_flash_omni/README.md @@ -0,0 +1,76 @@ +# Ming-flash-omni 2.0 + +[Ming-flash-omni-2.0](https://github.com/inclusionAI/Ming) is an omni-modal model supporting text, image, video, and audio understanding, with outputs in text, image, and audio. For now, Ming-flash-omni-2.0 in vLLM-Omni is supported with thinker stage (multi-modal understanding). + +## Setup + +Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup. + +## Run examples + +### Text-only +```bash +python examples/offline_inference/ming_flash_omni/end2end.py --query-type text +``` + +#### Reasoning (Thinking Mode) + +Reasoning (Thinking) mode is enabled via applying "detailed thinking on" when building the system prompt template (in `apply_chat_template`). + +In the end2end example, a default problem for thinking mode is provided, as referred to the example usage of Ming's cookbook; +To utilize it, you have to download the example figure from https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/figures/cases/3_0.png + +```bash +python examples/offline_inference/ming_flash_omni/end2end.py -q reasoning --image-path ./3_0.png +``` + +### Image understanding +```bash +python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_image + +# With a local image +python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_image --image-path /path/to/image.jpg +``` + +### Audio understanding +```bash +python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_audio + +# With a local audio file +python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_audio --audio-path /path/to/audio.wav +``` + +### Video understanding +```bash +python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_video + +# With a local video and custom frame count +python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_video --video-path /path/to/video.mp4 --num-frames 16 +``` + +### Mixed modalities (image + audio) +```bash +python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_mixed_modalities \ + --image-path /path/to/image.jpg \ + --audio-path /path/to/audio.wav +``` + +If media file paths are not provided, the script uses built-in default assets. + +### Modality control +To control output modalities (e.g. text-only output): +```bash +python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_audio --modalities text +``` + +*For now, only text output is supported* + +### Custom stage config +```bash +python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_image \ + --stage-configs-path /path/to/your_config.yaml +``` + +## Online serving + +For online serving via the OpenAI-compatible API, see [examples/online_serving/ming_flash_omni/README.md](../../online_serving/ming_flash_omni/README.md). diff --git a/examples/offline_inference/ming_flash_omni/end2end.py b/examples/offline_inference/ming_flash_omni/end2end.py new file mode 100644 index 00000000000..49cdbcc0186 --- /dev/null +++ b/examples/offline_inference/ming_flash_omni/end2end.py @@ -0,0 +1,485 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Partial example cases are referred from +# https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/cookbook.ipynb +import os +import time +from typing import NamedTuple + +import librosa +import numpy as np +import vllm +from PIL import Image +from transformers import AutoProcessor +from vllm import SamplingParams +from vllm.assets.audio import AudioAsset +from vllm.assets.image import ImageAsset +from vllm.assets.video import VideoAsset, video_to_ndarrays +from vllm.multimodal.image import convert_image_mode +from vllm.utils.argparse_utils import FlexibleArgumentParser + +import vllm_omni +from vllm_omni.entrypoints.omni import Omni + +# Imports the processor also registers itself +from vllm_omni.transformers_utils.processors.ming import MingFlashOmniProcessor # noqa: F401 + +SEED = 42 +MODEL_NAME = "Jonathan1909/Ming-flash-omni-2.0" + + +class QueryResult(NamedTuple): + inputs: dict + limit_mm_per_prompt: dict[str, int] + + +def get_text_query(processor: MingFlashOmniProcessor, question: str | None = None) -> QueryResult: + if question is None: + question = "请详细介绍鹦鹉的生活习性。" + conversation = [{"role": "HUMAN", "content": question}] + prompt = processor.apply_chat_template(conversation, tokenize=False) + return QueryResult( + inputs={"prompt": prompt}, + limit_mm_per_prompt={}, + ) + + +def get_image_query( + processor: MingFlashOmniProcessor, + question: str | None = None, + image_path: str | None = None, +) -> QueryResult: + if question is None: + question = "Describe this image in detail." + + if image_path: + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image file not found: {image_path}") + image_data = convert_image_mode(Image.open(image_path), "RGB") + else: + image_data = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") + + conversation = [ + { + "role": "HUMAN", + "content": [ + {"type": "image", "image": image_data}, + {"type": "text", "text": question}, + ], + } + ] + prompt = processor.apply_chat_template(conversation, tokenize=False) + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": {"image": image_data}, + }, + limit_mm_per_prompt={"image": 1}, + ) + + +def get_audio_query( + processor: MingFlashOmniProcessor, + question: str | None = None, + audio_path: str | None = None, + sampling_rate: int = 16000, +) -> QueryResult: + if question is None: + question = "Please recognize the language of this speech and transcribe it. Format: oral." + + if audio_path: + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + audio_signal, sr = librosa.load(audio_path, sr=sampling_rate) + audio_data = (audio_signal.astype(np.float32), sr) + else: + audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate + + # Use a string for "audio" so the processor counts it as 1 audio input + conversation = [ + { + "role": "HUMAN", + "content": [ + {"type": "audio", "audio": "input"}, + {"type": "text", "text": question}, + ], + } + ] + prompt = processor.apply_chat_template(conversation, tokenize=False) + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": {"audio": audio_data}, + }, + limit_mm_per_prompt={"audio": 1}, + ) + + +def get_video_query( + processor: MingFlashOmniProcessor, + question: str | None = None, + video_path: str | None = None, + num_frames: int = 16, +) -> QueryResult: + if question is None: + question = "Describe what is happening in this video." + + if video_path: + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file not found: {video_path}") + video_frames = video_to_ndarrays(video_path, num_frames=num_frames) + else: + video_frames = VideoAsset(name="baby_reading", num_frames=num_frames).np_ndarrays + + conversation = [ + { + "role": "HUMAN", + "content": [ + {"type": "video"}, + {"type": "text", "text": question}, + ], + } + ] + prompt = processor.apply_chat_template(conversation, tokenize=False) + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": {"video": video_frames}, + }, + limit_mm_per_prompt={"video": 1}, + ) + + +def get_mixed_modalities_query( + processor: MingFlashOmniProcessor, + image_path: str | None = None, + audio_path: str | None = None, + sampling_rate: int = 16000, +) -> QueryResult: + """Mixed image + audio understanding.""" + question = "Describe the image, and recognize the language of this speech and transcribe it. Format: oral" + + if image_path: + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image file not found: {image_path}") + image_data = convert_image_mode(Image.open(image_path), "RGB") + else: + image_data = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") + + if audio_path: + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + sig, sr = librosa.load(audio_path, sr=sampling_rate) + audio_data = (sig.astype(np.float32), sr) + else: + audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate + + conversation = [ + { + "role": "HUMAN", + "content": [ + {"type": "image", "image": image_data}, + {"type": "audio", "audio": "input"}, + {"type": "text", "text": question}, + ], + } + ] + prompt = processor.apply_chat_template(conversation, tokenize=False) + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": {"image": image_data, "audio": audio_data}, + }, + limit_mm_per_prompt={"image": 1, "audio": 1}, + ) + + +def get_reasoning_query( + processor: MingFlashOmniProcessor, + question: str | None = None, + image_path: str | None = None, +) -> QueryResult: + if question is None: + # NOTE: To use the following default question, input with example figure provided by Ming + # https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/figures/cases/3_0.png + # E.g., + # python examples/offline_inference/ming_flash_omni/end2end.py -q reasoning --image-path ./3_0.png + # Otherwise, the problem solving might be false. + question = ( + "Based on the following rules:\n•\tYou control the smiley face character\n" + "•\tYou can move up, down, left, and right, and only a single square at a time\n" + "•\tWalls are dark grey and cannot be moved into\n•\tThe brown square is a box\n•" + "\tThe box can be pushed by moving into it (i.e., if you are in the square " + "adjacent to the box to the left, and move onto the square with the box, " + "the box will move one square to the right).\n" + "•\tThe box cannot be pushed into walls\n" + "•\tThe blue door at the bottom is locked and cannot be passed through, " + "unless the box is placed on the blue square\n" + "•\tThe square beneath the blue door is the exit\n" + "•\tMoving from one square to another\n\n" + "Let's assume a coordinate system where the smiley face is " + "on the top left at (1,1) and the square below it is (1,2). " + "The smiley face performs the following moves: {down, right, right, right}, " + "such that the smiley face is at square (4,2) and the box is in square (5,2). " + "What are the next sequence of moves that must be done to move the box down to (5,3)? " + "Give your answer as a comma separated list." + ) + + if image_path: + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image file not found: {image_path}") + image_data = convert_image_mode(Image.open(image_path), "RGB") + conversation = [ + { + "role": "HUMAN", + "content": [ + {"type": "image", "image": image_data}, + {"type": "text", "text": question}, + ], + } + ] + prompt = processor.apply_chat_template(conversation, tokenize=False, use_cot_system_prompt=True) + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": {"image": image_data}, + }, + limit_mm_per_prompt={"image": 1}, + ) + + conversation = [{"role": "HUMAN", "content": question}] + prompt = processor.apply_chat_template(conversation, tokenize=False, use_cot_system_prompt=True) + return QueryResult( + inputs={"prompt": prompt}, + limit_mm_per_prompt={}, + ) + + +query_map = { + "text": get_text_query, + "use_audio": get_audio_query, + "use_image": get_image_query, + "use_video": get_video_query, + "use_mixed_modalities": get_mixed_modalities_query, + "reasoning": get_reasoning_query, +} + + +def main(args): + print( + "=" * 20, + "\n", + f"vllm version: {vllm.__version__}\n", + f"vllm-omni version: {vllm_omni.__version__}\n", + "=" * 20, + sep="", + ) + + processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True) + assert isinstance(processor, MingFlashOmniProcessor), f"Wrong processor type being used: {type(processor)}" + + query_func = query_map[args.query_type] + if args.query_type == "use_image": + query_result = query_func(processor, image_path=args.image_path) + elif args.query_type == "use_audio": + query_result = query_func(processor, audio_path=args.audio_path, sampling_rate=args.sampling_rate) + elif args.query_type == "use_video": + query_result = query_func(processor, video_path=args.video_path, num_frames=args.num_frames) + elif args.query_type == "use_mixed_modalities": + query_result = query_func( + processor, + image_path=args.image_path, + audio_path=args.audio_path, + sampling_rate=args.sampling_rate, + ) + elif args.query_type == "reasoning": + query_result = query_func(processor, image_path=args.image_path) + else: + query_result = query_func(processor) + + # Initialize Omni (with thinker-only stage config) + omni = Omni( + model=MODEL_NAME, + stage_configs_path=args.stage_configs_path, + log_stats=args.log_stats, + init_timeout=args.init_timeout, + stage_init_timeout=args.stage_init_timeout, + ) + + # Thinker sampling params + thinker_sampling_params = SamplingParams( + temperature=0.4, + top_p=0.9, + max_tokens=args.max_tokens, + repetition_penalty=1.05, + seed=SEED, + detokenize=True, + ) + sampling_params_list = [thinker_sampling_params] + + prompts = [query_result.inputs for _ in range(args.num_prompts)] + + if args.modalities is not None: + output_modalities = args.modalities.split(",") + for prompt in prompts: + prompt["modalities"] = output_modalities + + total_requests = len(prompts) + processed_count = 0 + print(f"Query type: {args.query_type}") + print(f"Number of prompts: {total_requests}") + + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + + profiler_enabled = args.enable_profiler + if profiler_enabled: + omni.start_profile(stages=args.profiler_stages) + + for stage_outputs in omni.generate(prompts, sampling_params_list): + output = stage_outputs.request_output + if stage_outputs.final_output_type == "text": + request_id = output.request_id + text_output = output.outputs[0].text + lines = [] + lines.append("Prompt:\n") + lines.append(str(output.prompt) + "\n") + lines.append("Text Output:\n") + lines.append(str(text_output).strip() + "\n") + print(*lines, sep="") + + # Save to file + out_txt = os.path.join(output_dir, f"{request_id}.txt") + try: + with open(out_txt, "w", encoding="utf-8") as f: + f.writelines(lines) + print(f"Request ID: {request_id}, text saved to {out_txt}") + except Exception as e: + print(f"Failed to write output file {out_txt}: {e}") + + elif stage_outputs.final_output_type == "audio": + raise NotImplementedError("Add audio example after talker supported.") + + processed_count += 1 + if profiler_enabled and processed_count >= total_requests: + print(f"[Info] Processed {processed_count}/{total_requests}. Stopping profiler inside active loop...") + # Stop the profiler while workers are still alive + omni.stop_profile(stages=args.profiler_stages) + + print("[Info] Waiting 30s for workers to write trace files to disk...") + time.sleep(30) + print("[Info] Trace export wait time finished.") + + omni.close() + + +def parse_args(): + parser = FlexibleArgumentParser(description="Ming-flash-omni 2.0 offline inference example") + parser.add_argument( + "--query-type", + "-q", + type=str, + default="text", + choices=query_map.keys(), + help="Query type.", + ) + parser.add_argument( + "--stage-configs-path", + type=str, + default=None, + help="Path to a stage configs YAML file.", + ) + parser.add_argument( + "--log-stats", + action="store_true", + default=False, + help="Enable detailed statistics logging.", + ) + parser.add_argument("--init-timeout", type=int, default=2000, help="Timeout for initializing in seconds.") + parser.add_argument( + "--stage-init-timeout", + type=int, + default=2000, + help="Timeout for initializing a single stage in seconds.", + ) + parser.add_argument( + "--enable-profiler", + action="store_true", + default=False, + help="Enables profiling when set.", + ) + parser.add_argument( + "--profiler-stages", + type=int, + nargs="*", + default=[0], + help="List of stage IDs to profile. If not set, profiles all stages.", + ) + parser.add_argument( + "--image-path", + "-i", + type=str, + default=None, + help="Path to local image file. Uses default asset if not provided.", + ) + parser.add_argument( + "--audio-path", + "-a", + type=str, + default=None, + help="Path to local audio file. Uses default asset if not provided.", + ) + parser.add_argument( + "--video-path", + "-v", + type=str, + default=None, + help="Path to local video file. Uses default asset if not provided.", + ) + parser.add_argument( + "--num-frames", + type=int, + default=16, + help="Number of frames to extract from video.", + ) + parser.add_argument( + "--sampling-rate", + type=int, + default=16000, + help="Sampling rate for audio loading.", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=16384, + help="Maximum tokens to generate.", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1, + help="Number of prompts to generate.", + ) + parser.add_argument( + "--modalities", + type=str, + default=None, + help="Output modalities (comma-separated).", + ) + parser.add_argument( + "--output-dir", + type=str, + default="output_ming", + help="Output directory for results.", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/online_serving/ming_flash_omni/README.md b/examples/online_serving/ming_flash_omni/README.md new file mode 100644 index 00000000000..502232725c2 --- /dev/null +++ b/examples/online_serving/ming_flash_omni/README.md @@ -0,0 +1,204 @@ +# Ming-flash-omni 2.0 + +## Installation + +Please refer to [README.md](../../../README.md) + +## Run examples (Ming-flash-omni 2.0) + +### Launch the Server + +```bash +vllm serve Jonathan1909/Ming-flash-omni-2.0 --omni --port 8091 +``` + +If you have custom stage configs file, launch the server with command below +```bash +vllm serve Jonathan1909/Ming-flash-omni-2.0 --omni --port 8091 --stage-configs-path /path/to/stage_configs_file +``` + +### Send Multi-modal Request + +#### Send request via python + +```bash +python examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py --model Jonathan1909/Ming-flash-omni-2.0 --query-type use_mixed_modalities --port 8091 --host "localhost" --modalities text +``` + +The Python client supports the following command-line arguments: + +- `--query-type` (or `-q`): Query type. Options: `text`, `use_audio`, `use_image`, `use_video`, `use_mixed_modalities` +- `--video-path` (or `-v`): Path to local video file or URL. If not provided and query-type uses video, uses default video URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs. Example: `--video-path /path/to/video.mp4` or `--video-path https://example.com/video.mp4` +- `--image-path` (or `-i`): Path to local image file or URL. If not provided and query-type uses image, uses default image URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common image formats: JPEG, PNG, GIF, WebP. Example: `--image-path /path/to/image.jpg` or `--image-path https://example.com/image.png` +- `--audio-path` (or `-a`): Path to local audio file or URL. If not provided and query-type uses audio, uses default audio URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common audio formats: MP3, WAV, OGG, FLAC, M4A. Example: `--audio-path /path/to/audio.wav` or `--audio-path https://example.com/audio.mp3` +- `--prompt` (or `-p`): Custom text prompt/question. If not provided, uses default prompt for the selected query type. Example: `--prompt "What are the main activities shown in this video?"` +- `--modalities`: Output modalities. For now, only `text` is supported. Example: `--modalities text` + + +#### Send request via curl + +```bash +bash run_curl_multimodal_generation.sh text +bash run_curl_multimodal_generation.sh use_image +bash run_curl_multimodal_generation.sh use_audio +bash run_curl_multimodal_generation.sh use_video +bash run_curl_multimodal_generation.sh use_mixed_modalities +``` + +## Modality control + +Ming-flash-omni 2.0 currently supports text output only (thinker stage). + +| Modalities | Output | +|------------|--------| +| `["text"]` | Text only | +| Not specified | Text only (default) | + +### Using curl + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Jonathan1909/Ming-flash-omni-2.0", + "messages": [ + {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]}, + {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"} + ], + "modalities": ["text"] + }' +``` + +### Using OpenAI Python SDK + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") + +response = client.chat.completions.create( + model="Jonathan1909/Ming-flash-omni-2.0", + messages=[ + {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]}, + {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"}, + ], + modalities=["text"], +) +print(response.choices[0].message.content) +``` + +### Multi-modal input with OpenAI Python SDK + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") + +response = client.chat.completions.create( + model="Jonathan1909/Ming-flash-omni-2.0", + messages=[ + {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]}, + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg"}}, + {"type": "text", "text": "Describe this image in detail."}, + ], + }, + ], + modalities=["text"], +) +print(response.choices[0].message.content) +``` + +## Streaming Output + +To enable streaming output: + +```bash +python examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py \ + --query-type use_image \ + --model Jonathan1909/Ming-flash-omni-2.0 \ + --modalities text \ + --stream +``` + +Or with the OpenAI Python SDK: + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") + +response = client.chat.completions.create( + model="Jonathan1909/Ming-flash-omni-2.0", + messages=[ + {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]}, + {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"}, + ], + modalities=["text"], + stream=True, +) +for chunk in response: + for choice in chunk.choices: + if hasattr(choice, "delta") and choice.delta.content: + print(choice.delta.content, end="", flush=True) +print() +``` + +Or using curl: + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Jonathan1909/Ming-flash-omni-2.0", + "messages": [ + {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]}, + {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"} + ], + "modalities": ["text"], + "stream": true, + }' +``` + + +## Reasoning (Thinking Mode) + +To enable reasoning/thinking mode, change `detailed thinking off` to `detailed thinking on` in the system prompt: + +### Using curl + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Jonathan1909/Ming-flash-omni-2.0", + "messages": [ + {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking on"}]}, + {"role": "user", "content": [ + {"type": "image_url", "image_url": {"url": "https://example.com/math_problem.png"}}, + {"type": "text", "text": "Solve this math problem step by step."} + ]} + ], + "modalities": ["text"] + }' +``` + +### Using OpenAI Python SDK + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") + +response = client.chat.completions.create( + model="Jonathan1909/Ming-flash-omni-2.0", + messages=[ + {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking on"}]}, + {"role": "user", "content": "If a train travels 120 km in 2 hours, what is its average speed?"}, + ], + modalities=["text"], +) +print(response.choices[0].message.content) +``` diff --git a/examples/online_serving/ming_flash_omni/run_curl_multimodal_generation.sh b/examples/online_serving/ming_flash_omni/run_curl_multimodal_generation.sh new file mode 100755 index 00000000000..768a424e451 --- /dev/null +++ b/examples/online_serving/ming_flash_omni/run_curl_multimodal_generation.sh @@ -0,0 +1,145 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Server port +PORT="${PORT:-8091}" +# Default query type +QUERY_TYPE="${1:-text}" + +# Validate query type +if [[ ! "$QUERY_TYPE" =~ ^(text|use_audio|use_image|use_video|use_mixed_modalities)$ ]]; then + echo "Error: Invalid query type '$QUERY_TYPE'" + echo "Usage: $0 [text|use_audio|use_image|use_video|use_mixed_modalities]" + echo " text: Text-only query" + echo " use_audio: Audio + Text query" + echo " use_image: Image + Text query" + echo " use_video: Video + Text query" + echo " use_mixed_modalities: Audio + Image + Video + Text query" + exit 1 +fi + +thinker_sampling_params='{ + "temperature": 0.4, + "top_p": 0.9, + "top_k": -1, + "max_tokens": 16384, + "seed": 42, + "detokenize": true, + "repetition_penalty": 1.05 +}' +# Above is optional, it has a default setting in stage_configs of the corresponding model. + +# Define URLs for assets +MARY_HAD_LAMB_AUDIO_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/mary_had_lamb.ogg" +CHERRY_BLOSSOM_IMAGE_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg" +SAMPLE_VIDEO_URL="https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4" + +# Build user content based on query type +case "$QUERY_TYPE" in + text) + user_content='[ + { + "type": "text", + "text": "请详细介绍鹦鹉的生活习性。" + } + ]' + ;; + use_image) + user_content='[ + { + "type": "image_url", + "image_url": { + "url": "'"$CHERRY_BLOSSOM_IMAGE_URL"'" + } + }, + { + "type": "text", + "text": "Describe this image in detail." + } + ]' + ;; + use_audio) + user_content='[ + { + "type": "audio_url", + "audio_url": { + "url": "'"$MARY_HAD_LAMB_AUDIO_URL"'" + } + }, + { + "type": "text", + "text": "Please recognize the language of this speech and transcribe it. Format: oral." + } + ]' + ;; + use_video) + user_content='[ + { + "type": "video_url", + "video_url": { + "url": "'"$SAMPLE_VIDEO_URL"'" + } + }, + { + "type": "text", + "text": "Describe what is happening in this video." + } + ]' + ;; + use_mixed_modalities) + user_content='[ + { + "type": "image_url", + "image_url": { + "url": "'"$CHERRY_BLOSSOM_IMAGE_URL"'" + } + }, + { + "type": "audio_url", + "audio_url": { + "url": "'"$MARY_HAD_LAMB_AUDIO_URL"'" + } + }, + { + "type": "text", + "text": "Describe the image, and recognize the language of this speech and transcribe it. Format: oral" + } + ]' + ;; +esac + +echo "Running query type: $QUERY_TYPE" +echo "" + +request_body=$(cat < str: + """Build a Ming chat prompt.""" + return ( + f"SYSTEM{SYSTEM_PROMPT}{EOS_TOKEN}HUMAN{user_text}{EOS_TOKEN}ASSISTANT" + ) + + +def get_eager_config(): + path = modify_stage_config( + str(Path(__file__).parent.parent / "stage_configs" / "bailingmm_moe_v2_lite_ci.yaml"), + updates={ + "stage_args": { + 0: { + "engine_args.enforce_eager": "true", + }, + }, + }, + ) + return path + + +stage_configs = [get_eager_config()] +test_params = [(model, stage_config) for model in models for stage_config in stage_configs] + + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=4) +@pytest.mark.parametrize("omni_runner", test_params, indirect=True) +def test_text_to_text(omni_runner, omni_runner_handler) -> None: + """ + Test text-only input processing and text output generation. + Input Modal: text + Output Modal: text + """ + prompt = build_prompt("请详细介绍鹦鹉的生活习性。") + request_config = {"prompts": prompt, "modalities": ["text"]} + + omni_runner_handler.send_request(request_config) + + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=4) +@pytest.mark.parametrize("omni_runner", test_params, indirect=True) +def test_image_to_text(omni_runner, omni_runner_handler) -> None: + """ + Test image understanding with text output. + Input Modal: image + text + Output Modal: text + """ + image = generate_synthetic_image(224, 224)["np_array"] + prompt = build_prompt(f"{IMAGE_TOKEN}Describe this image briefly.") + request_config = {"prompts": prompt, "images": image, "modalities": ["text"]} + + omni_runner_handler.send_request(request_config) + + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=4) +@pytest.mark.parametrize("omni_runner", test_params, indirect=True) +def test_audio_to_text(omni_runner, omni_runner_handler) -> None: + """ + Test audio understanding with text output. + Input Modal: audio + text + Output Modal: text + """ + audio = generate_synthetic_audio(2, 1, 16000)["np_array"] + if len(audio.shape) == 2: + audio = audio.squeeze() + prompt = build_prompt(f"{AUDIO_TOKEN}Please recognize the language of this speech and transcribe it. Format: oral.") + request_config = {"prompts": prompt, "audios": audio, "modalities": ["text"]} + + omni_runner_handler.send_request(request_config) + + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=4) +@pytest.mark.parametrize("omni_runner", test_params, indirect=True) +def test_video_to_text(omni_runner, omni_runner_handler) -> None: + """ + Test video understanding with text output. + Input Modal: video + text + Output Modal: text + """ + video = generate_synthetic_video(224, 224, 30)["np_array"] + prompt = build_prompt(f"{VIDEO_TOKEN}Describe what is happening in this video.") + request_config = {"prompts": prompt, "videos": video, "modalities": ["text"]} + + omni_runner_handler.send_request(request_config) + + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=4) +@pytest.mark.parametrize("omni_runner", test_params, indirect=True) +def test_mixed_to_text(omni_runner, omni_runner_handler) -> None: + """ + Test mixed modality input (image + audio) with text output. + Input Modal: image + audio + text + Output Modal: text + """ + image = generate_synthetic_image(224, 224)["np_array"] + audio = generate_synthetic_audio(2, 1, 16000)["np_array"] + if len(audio.shape) == 2: + audio = audio.squeeze() + prompt = build_prompt(f"{IMAGE_TOKEN}{AUDIO_TOKEN}Describe the image and transcribe the audio.") + request_config = {"prompts": prompt, "images": image, "audios": audio, "modalities": ["text"]} + + omni_runner_handler.send_request(request_config) diff --git a/tests/e2e/online_serving/test_ming_flash_omni.py b/tests/e2e/online_serving/test_ming_flash_omni.py new file mode 100644 index 00000000000..35b7b64c061 --- /dev/null +++ b/tests/e2e/online_serving/test_ming_flash_omni.py @@ -0,0 +1,247 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +E2E online serving tests for Ming-flash-omni-2.0 model (Thinker stage). +Tests multimodal understanding via OpenAI-compatible API. +""" + +import os +from pathlib import Path + +import pytest + +from tests.conftest import ( + OmniServerParams, + dummy_messages_from_mix_data, + generate_synthetic_audio, + generate_synthetic_image, + generate_synthetic_video, + modify_stage_config, +) +from tests.utils import hardware_test + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" + +models = ["Jonathan1909/Ming-flash-omni-2.0"] + + +def get_eager_config(): + path = modify_stage_config( + str(Path(__file__).parent.parent / "stage_configs" / "bailingmm_moe_v2_lite_ci.yaml"), + updates={ + "stage_args": { + 0: { + "engine_args.enforce_eager": "true", + }, + }, + }, + ) + return path + + +stage_configs = [get_eager_config()] + +# Create parameter combinations for model and stage config +test_params = [ + OmniServerParams(model=model, stage_config_path=stage_config) for model in models for stage_config in stage_configs +] + + +def get_system_prompt(): + return { + "role": "system", + "content": [ + { + "type": "text", + "text": "你是一个友好的AI助手。\n\ndetailed thinking off", + } + ], + } + + +def get_prompt(prompt_type="text_only"): + prompts = { + "text_only": "What is the capital of China? Answer in 20 words.", + "text_image": "What is in this image?", + "text_audio": "What is in this audio?", + "text_video": "What is in this video?", + "mix": "What is recited in the audio? What is in this image? What is in this video?", + } + return prompts.get(prompt_type, prompts["text_only"]) + + +def get_max_batch_size(size_type="few"): + batch_sizes = {"few": 5, "medium": 100, "large": 256} + return batch_sizes.get(size_type, 5) + + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=4) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_text_to_text_001(omni_server, openai_client) -> None: + """ + Input Modal: text + Output Modal: text + Input Setting: stream=False + Datasets: single request + """ + messages = dummy_messages_from_mix_data( + system_prompt=get_system_prompt(), + content_text=get_prompt("text_only"), + ) + + request_config = { + "model": omni_server.model, + "messages": messages, + "stream": False, + "modalities": ["text"], + "key_words": {"text": ["beijing"]}, + } + + openai_client.send_omni_request(request_config) + + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=4) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_text_to_text_stream_001(omni_server, openai_client) -> None: + """ + Input Modal: text + Output Modal: text + Input Setting: stream=True + Datasets: few requests + """ + messages = dummy_messages_from_mix_data( + system_prompt=get_system_prompt(), + content_text=get_prompt("text_only"), + ) + + request_config = { + "model": omni_server.model, + "messages": messages, + "stream": True, + "modalities": ["text"], + "key_words": {"text": ["beijing"]}, + } + + openai_client.send_omni_request(request_config, request_num=get_max_batch_size()) + + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=4) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_image_to_text_001(omni_server, openai_client) -> None: + """ + Input Modal: image + text + Output Modal: text + Input Setting: stream=True + Datasets: single request + """ + image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}" + messages = dummy_messages_from_mix_data( + system_prompt=get_system_prompt(), + image_data_url=image_data_url, + content_text=get_prompt("text_image"), + ) + + request_config = { + "model": omni_server.model, + "messages": messages, + "stream": True, + "modalities": ["text"], + } + + openai_client.send_omni_request(request_config) + + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=4) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_audio_to_text_001(omni_server, openai_client) -> None: + """ + Input Modal: audio + text + Output Modal: text + Input Setting: stream=True + Datasets: single request + """ + audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(2, 1)['base64']}" + messages = dummy_messages_from_mix_data( + system_prompt=get_system_prompt(), + audio_data_url=audio_data_url, + content_text=get_prompt("text_audio"), + ) + + request_config = { + "model": omni_server.model, + "messages": messages, + "stream": True, + "modalities": ["text"], + } + + openai_client.send_omni_request(request_config) + + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=4) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_video_to_text_001(omni_server, openai_client) -> None: + """ + Input Modal: video + text + Output Modal: text + Input Setting: stream=False + Datasets: single request + """ + video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}" + messages = dummy_messages_from_mix_data( + system_prompt=get_system_prompt(), + video_data_url=video_data_url, + content_text=get_prompt("text_video"), + ) + + request_config = { + "model": omni_server.model, + "messages": messages, + "stream": False, + "modalities": ["text"], + } + + openai_client.send_omni_request(request_config) + + +@pytest.mark.advanced_model +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=4) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_mix_to_text_001(omni_server, openai_client) -> None: + """ + Input Modal: text + audio + image + video + Output Modal: text + Input Setting: stream=True + Datasets: single request + """ + video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}" + image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}" + audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(2, 1)['base64']}" + messages = dummy_messages_from_mix_data( + system_prompt=get_system_prompt(), + video_data_url=video_data_url, + image_data_url=image_data_url, + audio_data_url=audio_data_url, + content_text=get_prompt("mix"), + ) + + request_config = { + "model": omni_server.model, + "messages": messages, + "stream": True, + "modalities": ["text"], + } + + openai_client.send_omni_request(request_config) diff --git a/tests/e2e/stage_configs/bailingmm_moe_v2_lite_ci.yaml b/tests/e2e/stage_configs/bailingmm_moe_v2_lite_ci.yaml new file mode 100644 index 00000000000..fb0c72cc513 --- /dev/null +++ b/tests/e2e/stage_configs/bailingmm_moe_v2_lite_ci.yaml @@ -0,0 +1,35 @@ +# Thinker stage only +stage_args: + - stage_id: 0 + stage_type: llm + runtime: + devices: "0,1,2,3" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: MingFlashOmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: false + trust_remote_code: true + engine_output_type: latent + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + max_model_len: 32768 + tensor_parallel_size: 4 + hf_config_name: llm_config + load_format: dummy + mm_processor_cache_gb: 0 + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + max_tokens: 100 + repetition_penalty: 1.05 + seed: 42 + detokenize: true + ignore_eos: false diff --git a/vllm_omni/model_executor/models/ming_flash_omni/__init__.py b/vllm_omni/model_executor/models/ming_flash_omni/__init__.py new file mode 100644 index 00000000000..d7fa44fd7e4 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_flash_omni/__init__.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 The vLLM-Omni team. + +from .ming_flash_omni import MingFlashOmniForConditionalGeneration +from .ming_flash_omni_thinker import ( + MingFlashOmniThinkerDummyInputsBuilder, + MingFlashOmniThinkerForConditionalGeneration, + MingFlashOmniThinkerMultiModalProcessor, + MingFlashOmniThinkerProcessingInfo, +) + +__all__ = [ + "MingFlashOmniForConditionalGeneration", + "MingFlashOmniThinkerForConditionalGeneration", + "MingFlashOmniThinkerProcessingInfo", + "MingFlashOmniThinkerMultiModalProcessor", + "MingFlashOmniThinkerDummyInputsBuilder", +] diff --git a/vllm_omni/model_executor/models/ming_flash_omni/audio_encoder.py b/vllm_omni/model_executor/models/ming_flash_omni/audio_encoder.py new file mode 100644 index 00000000000..6ca19901141 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_flash_omni/audio_encoder.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 The vLLM-Omni team. +# Copyright 2024 ANT Group and the HuggingFace Inc. team. +# Copyright (c) 2022 OpenAI +# Adapted from Ming repository modeling_whisper_encoder.py +# https://github.com/inclusionAI/Ming + +import operator +from collections.abc import Iterable +from itertools import accumulate + +import torch +import torch.nn as nn +import torch.nn.functional as F +from vllm.logger import init_logger +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.backends.utils.fa import HAS_FLASH_ATTN, flash_attn_varlen_func +from vllm_omni.model_executor.models.whisper_utils import Conv1d, Linear, sinusoids + +logger = init_logger(__name__) + + +class MultiHeadAttention(nn.Module): + """Multi-head attention with packed sequence support. + Adapted from Qwen3-TTS WhisperEncoder. + """ + + def __init__(self, n_state: int, n_head: int, use_flash_attn: bool = True): + super().__init__() + self.n_head = n_head + self.query = Linear(n_state, n_state) + self.key = Linear(n_state, n_state, bias=False) + self.value = Linear(n_state, n_state) + self.out = Linear(n_state, n_state) + + if use_flash_attn and not HAS_FLASH_ATTN: + logger.warning("flash-attn is not available. Fallback to manual PyTorch version") + self.use_flash_attn = use_flash_attn and HAS_FLASH_ATTN + + def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + """Forward pass with packed sequence support. + + Args: + x: [total_tokens, n_state] packed sequence + cu_seqlens: [num_seqs + 1] cumulative sequence lengths, e.g. [0, len1, len1+len2, ...] + + Returns: + [total_tokens, n_state] attention output + """ + q = self.query(x) + k = self.key(x) + v = self.value(x) + + n_ctx, n_state = q.shape + head_dim = n_state // self.n_head + + q = q.view(n_ctx, self.n_head, head_dim) + k = k.view(n_ctx, self.n_head, head_dim) + v = v.view(n_ctx, self.n_head, head_dim) + + # Try flash attention varlen + if self.use_flash_attn and cu_seqlens is not None and q.dtype in [torch.float16, torch.bfloat16]: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen) + else: + attn_output = self._manual_attention(q, k, v, cu_seqlens) + + # Reshape back: [T, H, D] -> [T, H*D] + attn_output = attn_output.contiguous().view(n_ctx, n_state) + return self.out(attn_output) + + def _manual_attention( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens: torch.Tensor + ) -> torch.Tensor: + """Manual attention for variable-length sequences (fallback).""" + _, n_head, head_dim = q.shape + scale = head_dim**-0.5 + + # Unpack sequences and pad to max length + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + batch_size = len(seqlens) + max_seqlen = max(seqlens) + + # Create padded tensors + q_padded = torch.zeros(batch_size, max_seqlen, n_head, head_dim, dtype=q.dtype, device=q.device) + k_padded = torch.zeros_like(q_padded) + v_padded = torch.zeros_like(q_padded) + + # Fill with actual sequences + for i in range(batch_size): + start_idx = cu_seqlens[i] + end_idx = cu_seqlens[i + 1] + seq_len = seqlens[i] + q_padded[i, :seq_len] = q[start_idx:end_idx] + k_padded[i, :seq_len] = k[start_idx:end_idx] + v_padded[i, :seq_len] = v[start_idx:end_idx] + + # Transpose for attention: [B, H, T, D] + q_padded = q_padded.transpose(1, 2) + k_padded = k_padded.transpose(1, 2) + v_padded = v_padded.transpose(1, 2) + + # Create attention mask for variable lengths: 0 for valid positions, -inf for padding + padding_mask = ( + torch.arange(max_seqlen, device=q.device)[None, :] >= torch.tensor(seqlens, device=q.device)[:, None] + ) + attn_mask = torch.zeros(batch_size, 1, 1, max_seqlen, dtype=q.dtype, device=q.device) + attn_mask = attn_mask.masked_fill(padding_mask.unsqueeze(1).unsqueeze(2), -torch.finfo(q.dtype).max) + + # Compute attention + attn_scores = torch.matmul(q_padded, k_padded.transpose(-2, -1)) * scale + attn_scores = attn_scores + attn_mask + attn_weights = F.softmax(attn_scores, dim=-1) + context = torch.matmul(attn_weights, v_padded) + + # Transpose back: [B, H, T, D] -> [B, T, H, D] + context = context.transpose(1, 2).contiguous() + output_packed = torch.cat([context[i, : seqlens[i]] for i in range(batch_size)], dim=0) + + return output_packed + + +class ResidualAttentionBlock(nn.Module): + """Whisper-style residual attention block with packed sequence support. + + Adapted from + https://github.com/openai/whisper/blob/v20250625/whisper/model.py + vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py + """ + + def __init__(self, n_state: int, n_head: int, use_flash_attn: bool = True): + super().__init__() + self.attn = MultiHeadAttention(n_state, n_head, use_flash_attn=use_flash_attn) + self.attn_ln = nn.LayerNorm(n_state) + + n_mlp = n_state * 4 + self.mlp = nn.Sequential( + Linear(n_state, n_mlp), + nn.GELU(), + Linear(n_mlp, n_state), + ) + self.mlp_ln = nn.LayerNorm(n_state) + + def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + x = x + self.attn(self.attn_ln(x), cu_seqlens=cu_seqlens) + x = x + self.mlp(self.mlp_ln(x)) + return x + + +class WhisperAudioEncoder(nn.Module): + """Whisper audio encoder for Ming with packed sequence support. + + Adapted from + https://github.com/openai/whisper/blob/v20250625/whisper/model.py + vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py + """ + + def __init__( + self, + n_mels: int = 128, + n_ctx: int = 15000, + n_state: int = 1280, + n_head: int = 20, + n_layer: int = 32, + use_flash_attn: bool = True, + ): + super().__init__() + self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) + self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) + # self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) + self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) + self.blocks = nn.ModuleList( + [ResidualAttentionBlock(n_state, n_head, use_flash_attn=use_flash_attn) for _ in range(n_layer)] + ) + self.ln_post = nn.LayerNorm(n_state) + self.audio_emb_dim = n_state + + self.n_layer = n_layer + self.n_mels = n_mels + self.use_flash_attn = use_flash_attn + + def forward( + self, + x_list: list[torch.Tensor], + audio_lens: list[int], + ) -> torch.Tensor: + """Forward pass with packed sequence format for variable-length inputs. + + Args: + x_list: List of [n_mels, T_i] mel spectrogram features for each audio + audio_lens: List of original audio lengths in frames + + Returns: + [total_T', n_state] packed encoded audio features, where + total_T' is the sum of all encoded sequence lengths + """ + # Cast inputs to model dtype + target_dtype = self.conv1.weight.dtype + x_list = [x.to(target_dtype) for x in x_list] + + encoded_list = [] + encoded_lens = [] + for mel_spec in x_list: + # mel_spec: [n_mels, T] - process through conv layers + x = mel_spec.unsqueeze(0) # [1, n_mels, T] + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = x.squeeze(0).transpose(0, 1) # [T', n_state] + + # Add positional embedding + seq_len = x.shape[0] + positional_embedding = self.positional_embedding[:seq_len, :] + x = (x + positional_embedding).to(x.dtype) + + encoded_list.append(x) + encoded_lens.append(seq_len) + + x_packed = torch.cat(encoded_list, dim=0) # [total_T', n_state] + + cu_seqlens = list(accumulate(encoded_lens, func=operator.add, initial=0)) + cu_seqlens = torch.tensor(cu_seqlens, device=x_packed.device, dtype=torch.int32) + + for block in self.blocks: + x_packed = block(x_packed, cu_seqlens=cu_seqlens) + + x_packed = self.ln_post(x_packed) + return x_packed + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + params_dict: dict[str, torch.Tensor] = { + **dict(self.named_parameters(remove_duplicate=False)), + **dict(self.named_buffers()), + } + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if name not in params_dict: + logger.warning("Skipping unknown audio encoder weight: %s", name) + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params diff --git a/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py new file mode 100644 index 00000000000..87728890b67 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 The vLLM-Omni team. +# Copyright 2024 ANT Group and the HuggingFace Inc. team. All rights reserved. +# Adapted from Ming repository modeling_bailingmm2.py +# https://github.com/inclusionAI/Ming +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Ming-flash-omni-2.0 unified model (thinker + imagegen + talker).""" + +from collections.abc import Iterable + +import torch +import torch.nn as nn +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import ( + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, +) +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.utils import ( + init_vllm_registered_model, + maybe_prefix, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import IntermediateTensors + +from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin +from vllm_omni.model_executor.models.output_templates import OmniOutput +from vllm_omni.model_executor.models.utils import add_prefix_to_loaded_weights +from vllm_omni.transformers_utils.configs.ming_flash_omni import BailingMM2Config, MingFlashOmniConfig + +from .ming_flash_omni_thinker import ( + MingFlashOmniThinkerDummyInputsBuilder, + MingFlashOmniThinkerMultiModalProcessor, + MingFlashOmniThinkerProcessingInfo, +) + +logger = init_logger(__name__) + + +@MULTIMODAL_REGISTRY.register_processor( + MingFlashOmniThinkerMultiModalProcessor, + info=MingFlashOmniThinkerProcessingInfo, + dummy_inputs=MingFlashOmniThinkerDummyInputsBuilder, +) +class MingFlashOmniForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsPP, + SupportsMRoPE, + CustomProcessMixin, +): + """Unified Ming-flash-omni-2.0 model combining thinker, imagegen, and talker.""" + + supports_multimodal = True + requires_raw_input_tokens: bool = True + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.have_multimodal_outputs = True + self.has_preprocess = False + self.has_postprocess = False + + config = vllm_config.model_config.hf_config + + self.vllm_config = vllm_config + self.config = config + + if isinstance(config, MingFlashOmniConfig): + thinker_config = config.thinker_config + else: + thinker_config = config + + self.thinker_config: BailingMM2Config = thinker_config + self.model_stage = vllm_config.model_config.model_stage + + if self.model_stage == "thinker": + thinker_vllm_config = vllm_config.with_hf_config( + thinker_config, architectures=["MingFlashOmniThinkerForConditionalGeneration"] + ) + self.thinker = init_vllm_registered_model( + vllm_config=thinker_vllm_config, + prefix=maybe_prefix(prefix, "thinker"), + architectures=["MingFlashOmniThinkerForConditionalGeneration"], + ) + self.model = self.thinker + self.imagegen = None + self.talker = None + + elif self.model_stage == "imagegen": + # TODO: Implement image generator stage + raise NotImplementedError( + "Image generation stage is not yet implemented. Please use model_stage='thinker' for now." + ) + + elif self.model_stage == "talker": + # TODO: Implement talker (TTS) stage + raise NotImplementedError( + "Talker (TTS) stage is not yet implemented. Please use model_stage='thinker' for now." + ) + + else: + raise ValueError( + f"Invalid model_stage: {self.model_stage}. Must be one of: 'thinker', 'imagegen', 'talker'" + ) + + # Set up intermediate tensors + self.make_empty_intermediate_tensors = ( + self.thinker.make_empty_intermediate_tensors if self.model_stage == "thinker" else lambda: None + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> OmniOutput: + return self.model.forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata=None, + ) -> torch.Tensor | None: + if hasattr(self.model, "compute_logits"): + return self.model.compute_logits(hidden_states, sampling_metadata) + return None + + def sample( + self, + logits: torch.Tensor, + sampling_metadata, + ): + if hasattr(self.model, "sample"): + return self.model.sample(logits, sampling_metadata) + raise NotImplementedError("sample method not available on current stage") + + def get_mrope_input_positions(self, *args, **kwargs): + if hasattr(self.model, "get_mrope_input_positions"): + return self.model.get_mrope_input_positions(*args, **kwargs) + raise NotImplementedError("get_mrope_input_positions not available on current stage") + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loaded_weights = set() + thinker_weights = [] + imagegen_weights = [] + talker_weights = [] + + for name, value in weights: + if name.startswith("thinker."): + thinker_weights.append((name, value)) + elif name.startswith("imagegen."): + imagegen_weights.append((name, value)) + elif name.startswith("talker."): + talker_weights.append((name, value)) + else: + # Weights without prefix go to thinker by default + thinker_weights.append((name, value)) + + if self.model_stage == "thinker" and thinker_weights: + # Remove "thinker." prefix before loading + thinker_weights_stripped = [ + (name.replace("thinker.", "", 1) if name.startswith("thinker.") else name, value) + for name, value in thinker_weights + ] + thinker_loaded = self.thinker.load_weights(thinker_weights_stripped) + thinker_loaded = add_prefix_to_loaded_weights(thinker_loaded, "thinker") + loaded_weights.update(thinker_loaded) + + # TODO: Load imagegen weights when implemented + # TODO: Load talker weights when implemented + + return loaded_weights + + def get_mm_mapping(self) -> MultiModelKeys: + return MultiModelKeys.from_string_field( + language_model="thinker.language_model", + connector=["thinker.linear_proj.", "thinker.linear_proj_audio."], + tower_model=["thinker.vision.", "thinker.audio."], + ) + + @property + def sampler(self): + if hasattr(self.model, "sampler"): + return self.model.sampler + return None + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings=None, + *, + is_multimodal=None, + ) -> torch.Tensor: + return self.model.embed_input_ids( + input_ids, + multimodal_embeddings, + is_multimodal=is_multimodal, + ) + + def embed_multimodal(self, **kwargs): + return self.model.embed_multimodal(**kwargs) diff --git a/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_thinker.py b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_thinker.py new file mode 100644 index 00000000000..bde7477b945 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_thinker.py @@ -0,0 +1,893 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 The vLLM-Omni team. +# Copyright 2024 ANT Group and the HuggingFace Inc. team. +# Adapted from Ming repository modeling_bailingmm2.py and processing_bailingmm2.py +# https://github.com/inclusionAI/Ming + +"""Ming-flash-omni-2.0 Thinker stage implementation (multimodal understanding).""" + +from collections.abc import Iterable, Iterator, Mapping, Sequence +from typing import Annotated, Any + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.feature_extraction_utils import BatchFeature +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict +from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import ( + MultiModalEmbeddings, + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, +) +from vllm.model_executor.models.qwen2_5_vl import ( + Qwen2_5_VLImageInputs, + Qwen2_5_VLImagePixelInputs, + Qwen2_5_VLVideoInputs, + Qwen2_5_VLVideoPixelInputs, +) +from vllm.model_executor.models.qwen2_vl import ( + Qwen2VLProcessingInfo, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + _merge_multimodal_embeddings, + maybe_prefix, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + ImageProcessorItems, + MultiModalDataItems, + MultiModalDataParser, + VideoProcessorItems, +) +from vllm.multimodal.processing import ( + BaseDummyInputsBuilder, + BaseMultiModalProcessor, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin +from vllm_omni.model_executor.models.output_templates import OmniOutput +from vllm_omni.transformers_utils.configs.ming_flash_omni import BailingMM2Config +from vllm_omni.transformers_utils.processors.ming import ( + PLACEHOLDER_AUDIO_TOKEN_IN_TEXT, + PLACEHOLDER_IMAGE_TOKEN_IN_TEXT, + PLACEHOLDER_VIDEO_TOKEN_IN_TEXT, + MingFlashOmniProcessor, + MingWhisperFeatureExtractor, +) + +from .audio_encoder import WhisperAudioEncoder +from .modeling_bailing_moe_v2 import BailingMoeV2ForCausalLM +from .projectors import AudioProjector, VisionProjector +from .vision_encoder import MingVisionEncoder + +logger = init_logger(__name__) + + +class MingAudioInput(TensorSchema): + """ + Dimensions: + - b: Batch size + - l: Total audio frames (clips concatenated along the time axis) + - nm: Number of mel bins + - N: Max number of audio clips per batch item + """ + + audio_feats: Annotated[ + torch.Tensor, + TensorShape("b", "l", "nm"), + ] + + audio_feats_lengths: Annotated[ + torch.Tensor, + TensorShape("b", "N"), + ] + + +class MingFlashOmniThinkerProcessingInfo(Qwen2VLProcessingInfo): + def get_hf_config(self) -> BailingMM2Config: + return self.ctx.get_hf_config(BailingMM2Config) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(MingFlashOmniProcessor, **kwargs) + + def get_target_channels(self) -> int: + # See `_normalize_audio_tensor` in vllm_omni/transformers_utils/processors/ming.py + return 1 + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"image": None, "video": None, "audio": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + mm_counts = mm_counts or {} + requested_modalities = {m for m, c in mm_counts.items() if c > 0} + mm_max_tokens: dict[str, int] = {} + + if requested_modalities & {"image", "video"}: + vl_tokens = super().get_mm_max_tokens_per_item( + seq_len=seq_len, + mm_counts=mm_counts, + ) + mm_max_tokens.update({m: vl_tokens[m] for m in ["image", "video"] if m in requested_modalities}) + + if "audio" in requested_modalities: + # TODO: consider computing from audio config + mm_max_tokens["audio"] = 3000 + + return mm_max_tokens + + def get_feature_extractor(self, **kwargs: object) -> MingWhisperFeatureExtractor: + hf_processor = self.get_hf_processor(**kwargs) + feature_extractor = hf_processor.audio_processor + assert isinstance(feature_extractor, MingWhisperFeatureExtractor) + return feature_extractor + + def get_data_parser(self): + feature_extractor = self.get_feature_extractor() + return MultiModalDataParser( + target_sr=feature_extractor.sampling_rate, + target_channels=self.get_target_channels(), + expected_hidden_size=self._get_expected_hidden_size(), + ) + + +class MingFlashOmniThinkerDummyInputsBuilder(BaseDummyInputsBuilder[MingFlashOmniThinkerProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + num_audios = mm_counts.get("audio", 0) + + hf_processor = self.info.get_hf_processor() + + audio_token: str = hf_processor.audio_token + image_token: str = hf_processor.image_token + video_token: str = hf_processor.video_token + + return image_token * num_images + video_token * num_videos + audio_token * num_audios + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + num_audios = mm_counts.get("audio", 0) + + # Default dimensions for dummy data + image_width, image_height = 448, 448 + video_width, video_height = 448, 448 + num_frames = 8 + audio_duration = 3.0 # seconds + sample_rate = 16000 + + audio_length = int(audio_duration * sample_rate) + + mm_data: MultiModalDataDict = { + "image": self._get_dummy_images( + width=image_width, + height=image_height, + num_images=num_images, + ), + "video": self._get_dummy_videos( + width=video_width, + height=video_height, + num_frames=num_frames, + num_videos=num_videos, + ), + "audio": [(np.random.randn(audio_length).astype(np.float32), sample_rate) for _ in range(num_audios)], + } + + return mm_data + + +class MingFlashOmniThinkerMultiModalProcessor(BaseMultiModalProcessor[MingFlashOmniThinkerProcessingInfo]): + """Multimodal processor for Ming-flash-omni Thinker stage. + + Handles preprocessing of 1) image, 2) video, and 3) audio inputs, + and expands placeholder tokens to the correct number of patch tokens. + """ + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + tokenizer = self.info.get_tokenizer() + # might want to add a fallback to resolve token ids + # vocab = tokenizer.get_vocab() + thinker_config = self.info.get_hf_config() + + # patch/delimiter token IDs (used in replacement sequences) + image_start_token_id = thinker_config.llm_config.image_start_token + image_patch_token_id = thinker_config.llm_config.image_patch_token + image_end_token_id = thinker_config.llm_config.image_end_token + + video_start_token_id = thinker_config.llm_config.video_start_token + frame_patch_token_id = thinker_config.llm_config.video_patch_token + video_end_token_id = thinker_config.llm_config.video_end_token + + audio_start_token_id = thinker_config.llm_config.audio_start_token + audio_patch_token_id = thinker_config.llm_config.audio_patch_token + audio_end_token_id = thinker_config.llm_config.audio_end_token + + vision_config = thinker_config.vision_config + spatial_merge_size = vision_config.spatial_merge_size if vision_config else 2 + + newline_token_ids: list[int] = tokenizer.encode("\n", add_special_tokens=False) + + out_mm_data = out_mm_kwargs.get_data() + + def get_replacement_image(item_idx: int) -> PromptUpdateDetails: + """Generate token sequence for an image.""" + grid_thw = out_mm_data.get("image_grid_thw") + if grid_thw is None: + raise ValueError( + "image_grid_thw missing from processor output; " + "cannot determine image patch count for prompt replacement." + ) + if isinstance(grid_thw, torch.Tensor): + thw = grid_thw[item_idx] + num_patches = int(thw.prod().item()) // (spatial_merge_size**2) + else: + thw = grid_thw[item_idx] + num_patches = (thw[0] * thw[1] * thw[2]) // (spatial_merge_size**2) + + # Build token sequence: *N \n + # the newline token is added in purpose from original model processing + tokens: list[int] = [] + tokens.append(image_start_token_id) + tokens.extend([image_patch_token_id] * num_patches) + tokens.append(image_end_token_id) + # Refer to Ming's BailingMM2Processor._expand_image_tokens + # https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/processing_bailingmm2.py + tokens.extend(newline_token_ids) + + # Only tokens receive multimodal embeddings + return PromptUpdateDetails.select_token_id(tokens, image_patch_token_id) + + def get_replacement_video(item_idx: int) -> PromptUpdateDetails: + """Generate token sequence for a video.""" + grid_thw = out_mm_data.get("video_grid_thw", None) + if grid_thw is None: + raise ValueError( + "video_grid_thw missing from processor output; " + "cannot determine video patch count for prompt replacement." + ) + if isinstance(grid_thw, torch.Tensor): + thw = grid_thw[item_idx] + num_patches = int(thw.prod().item()) // (spatial_merge_size**2) + else: + thw = grid_thw[item_idx] + num_patches = (thw[0] * thw[1] * thw[2]) // (spatial_merge_size**2) + + # Build token sequence: \n + # the newline token is added in purpose from original model processing + tokens: list[int] = [] + tokens.append(video_start_token_id) + tokens.extend([frame_patch_token_id] * num_patches) + tokens.append(video_end_token_id) + tokens.extend(newline_token_ids) + + # Only tokens receive multimodal embeddings + return PromptUpdateDetails.select_token_id(tokens, frame_patch_token_id) + + def get_replacement_audio(item_idx: int) -> PromptUpdateDetails: + """Generate token sequence for an audio.""" + encoder_feats_lengths = out_mm_data.get("encoder_feats_lengths", None) + if encoder_feats_lengths is None: + raise ValueError( + "encoder_feats_lengths missing from processor output; " + "cannot determine audio patch count for prompt replacement." + ) + if isinstance(encoder_feats_lengths, torch.Tensor): + num_patches = int(encoder_feats_lengths[item_idx].item()) + else: + num_patches = encoder_feats_lengths[item_idx] + + # Build token sequence: + tokens: list[int] = [] + tokens.append(audio_start_token_id) + tokens.extend([audio_patch_token_id] * num_patches) + tokens.append(audio_end_token_id) + + # Only tokens receive multimodal embeddings + return PromptUpdateDetails.select_token_id(tokens, audio_patch_token_id) + + # Build prompt updates and process replacement + updates: list[PromptUpdate] = [] + + if "image" in mm_items and mm_items.get_items("image", ImageProcessorItems): + updates.append( + PromptReplacement( + modality="image", + target=PLACEHOLDER_IMAGE_TOKEN_IN_TEXT, + replacement=get_replacement_image, + ) + ) + if "video" in mm_items and mm_items.get_items("video", VideoProcessorItems): + updates.append( + PromptReplacement( + modality="video", + target=PLACEHOLDER_VIDEO_TOKEN_IN_TEXT, + replacement=get_replacement_video, + ) + ) + if "audio" in mm_items and mm_items.get_items("audio", AudioProcessorItems): + updates.append( + PromptReplacement( + modality="audio", + target=PLACEHOLDER_AUDIO_TOKEN_IN_TEXT, + replacement=get_replacement_audio, + ) + ) + return updates + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + config: dict[str, MultiModalFieldConfig] = {} + + # Image fields, pixel_values is flat (concatenated patches from all images) + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + if "pixel_values" in hf_inputs: + image_sizes = image_grid_thw.prod(-1) + config["pixel_values"] = MultiModalFieldConfig.flat_from_sizes( + "image", + image_sizes, + ) + if "image_grid_thw" in hf_inputs: + config["image_grid_thw"] = MultiModalFieldConfig.batched("image") + + # Video fields, same flat layout as images + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + if "pixel_values_videos" in hf_inputs: + video_sizes = video_grid_thw.prod(-1) + config["pixel_values_videos"] = MultiModalFieldConfig.flat_from_sizes( + "video", + video_sizes, + ) + if "video_grid_thw" in hf_inputs: + config["video_grid_thw"] = MultiModalFieldConfig.batched("video") + + # Audio fields + if "audio_feats" in hf_inputs: + config["audio_feats"] = MultiModalFieldConfig.batched("audio") + if "audio_feats_lengths" in hf_inputs: + config["audio_feats_lengths"] = MultiModalFieldConfig.batched("audio") + if "encoder_feats_lengths" in hf_inputs: + config["encoder_feats_lengths"] = MultiModalFieldConfig.batched("audio") + if "placeholder_audio_loc_lens" in hf_inputs: + config["placeholder_audio_loc_lens"] = MultiModalFieldConfig.batched("audio") + + return config + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + """Call sub-processors for multimodal inputs and tokenize. + + We call the image/audio sub-processors directly (instead of going + through `MingFlashOmniProcessor.__call__`) so that the high-level + placeholder tokens remain **unexpanded** in the tokenized output. + """ + hf_processor = self.info.get_hf_processor() + tokenizer = self.info.get_tokenizer() + + data: dict[str, object] = {} + + images = mm_data.get("images", None) + if images is not None: + image_outputs = hf_processor.image_processor( + images=images, + videos=None, + return_tensors="pt", + ) + data.update(image_outputs) + + videos = mm_data.get("videos", None) + if videos is not None: + video_outputs = hf_processor.image_processor( + images=None, + videos=videos, + return_tensors="pt", + ) + # Rename keys to distinguish from images + if "pixel_values" in video_outputs: + video_outputs["pixel_values_videos"] = video_outputs.pop("pixel_values") + if "image_grid_thw" in video_outputs: + video_outputs["video_grid_thw"] = video_outputs.pop("image_grid_thw") + data.update(video_outputs) + + audios = mm_data.get("audios", None) + if audios is not None: + # vLLM's AudioProcessorItems provides raw numpy arrays (already resampled). + # MingWhisperAudioProcessor expects (waveform, sr) tuples, + # so wrap them with the target sample rate. + target_sr = hf_processor.audio_processor.sampling_rate + audio_tuples = [(a, target_sr) if not isinstance(a, tuple) else a for a in audios] + + audio_outputs = hf_processor.audio_processor( + audio_tuples, + return_tensors="pt", + ) + data.update(audio_outputs) + + # Tokenize text with placeholders still intact + text_outputs = tokenizer(prompt, return_tensors="pt", **tok_kwargs) + data.update(text_outputs) + + return BatchFeature(data=data) + + +@MULTIMODAL_REGISTRY.register_processor( + MingFlashOmniThinkerMultiModalProcessor, + info=MingFlashOmniThinkerProcessingInfo, + dummy_inputs=MingFlashOmniThinkerDummyInputsBuilder, +) +class MingFlashOmniThinkerForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsPP, + SupportsMRoPE, + CustomProcessMixin, +): + """Ming Thinker stage: multimodal understanding + (text + image + video + audio) -> text generation. + """ + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={"model.": "language_model."}, + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + # vllm_omni/transformers_utils/processors/ming.py + if modality.startswith("image"): + return "" + elif modality.startswith("video"): + return "