diff --git a/tests/config/__init__.py b/tests/config/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/config/test_pipeline_registry.py b/tests/config/test_pipeline_registry.py
new file mode 100644
index 00000000000..3483d530c63
--- /dev/null
+++ b/tests/config/test_pipeline_registry.py
@@ -0,0 +1,111 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for the central pipeline registry (2.5/N)."""
+
+from __future__ import annotations
+
+import pytest
+
+from vllm_omni.config.pipeline_registry import (
+    _DIFFUSION_PIPELINES,
+    _OMNI_PIPELINES,
+    _VLLM_OMNI_PIPELINES,
+)
+from vllm_omni.config.stage_config import (
+    _PIPELINE_REGISTRY,
+    PipelineConfig,
+    StageExecutionType,
+    StagePipelineConfig,
+    register_pipeline,
+)
+
+
+class TestCentralRegistryDeclarations:
+    """Every in-tree pipeline must be declared exactly once in the central registry."""
+
+    def test_union_contains_all_omni(self):
+        for key in _OMNI_PIPELINES:
+            assert key in _VLLM_OMNI_PIPELINES
+
+    def test_union_contains_all_diffusion(self):
+        for key in _DIFFUSION_PIPELINES:
+            assert key in _VLLM_OMNI_PIPELINES
+
+    def test_no_duplicate_model_type_between_omni_and_diffusion(self):
+        overlap = set(_OMNI_PIPELINES) & set(_DIFFUSION_PIPELINES)
+        assert not overlap, f"Duplicate model_types across omni/diffusion: {overlap}"
+
+    def test_expected_omni_pipelines_present(self):
+        # Guard against accidental removal during future refactors.
+        assert "qwen2_5_omni" in _OMNI_PIPELINES
+        assert "qwen2_5_omni_thinker_only" in _OMNI_PIPELINES
+        assert "qwen3_omni_moe" in _OMNI_PIPELINES
+        assert "qwen3_tts" in _OMNI_PIPELINES
+
+
+class TestLazyLoading:
+    """Pipelines are imported only on first access."""
+
+    def test_contains_without_import(self):
+        # ``in`` hits the lazy map, not the loaded cache.
+        assert "qwen3_omni_moe" in _PIPELINE_REGISTRY
+
+    def test_getitem_loads_correct_pipeline(self):
+        pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
+        assert pipeline.model_type == "qwen3_omni_moe"
+        assert pipeline.model_arch == "Qwen3OmniMoeForConditionalGeneration"
+
+    def test_unknown_model_type_returns_none_via_get(self):
+        assert _PIPELINE_REGISTRY.get("not_a_real_pipeline") is None
+
+    def test_unknown_model_type_raises_keyerror_via_getitem(self):
+        with pytest.raises(KeyError):
+            _PIPELINE_REGISTRY["not_a_real_pipeline"]
+
+    def test_iteration_yields_registered_pipelines(self):
+        keys = set(_PIPELINE_REGISTRY)
+        assert "qwen2_5_omni" in keys
+        assert "qwen3_omni_moe" in keys
+
+
+class TestDynamicRegistration:
+    """``register_pipeline()`` still works for plugins and tests."""
+
+    def test_register_adds_to_registry(self):
+        custom = PipelineConfig(
+            model_type="_test_dynamic_registration",
+            model_arch="DynamicTestModel",
+            stages=(
+                StagePipelineConfig(
+                    stage_id=0,
+                    model_stage="test",
+                    execution_type=StageExecutionType.LLM_AR,
+                    input_sources=(),
+                    final_output=True,
+                ),
+            ),
+        )
+        register_pipeline(custom)
+        try:
+            assert "_test_dynamic_registration" in _PIPELINE_REGISTRY
+            assert _PIPELINE_REGISTRY["_test_dynamic_registration"] is custom
+        finally:
+            # Don't leak the test registration into other tests.
+            if "_test_dynamic_registration" in _PIPELINE_REGISTRY:
+                del _PIPELINE_REGISTRY["_test_dynamic_registration"]
+
+    def test_dynamic_registration_overrides_lazy_entry(self):
+        # Build a substitute for qwen3_omni_moe that we can distinguish.
+        original = _PIPELINE_REGISTRY["qwen3_omni_moe"]
+        override = PipelineConfig(
+            model_type="qwen3_omni_moe",
+            model_arch="OverriddenArch",
+            stages=original.stages,
+        )
+        register_pipeline(override)
+        try:
+            assert _PIPELINE_REGISTRY["qwen3_omni_moe"].model_arch == "OverriddenArch"
+        finally:
+            # Remove the dynamic override so later tests see the original.
+            if "qwen3_omni_moe" in _PIPELINE_REGISTRY._loaded:
+                del _PIPELINE_REGISTRY["qwen3_omni_moe"]
diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py
index 2e2bdc75dcd..1d65d3acd27 100644
--- a/tests/test_config_factory.py
+++ b/tests/test_config_factory.py
@@ -23,7 +23,7 @@
     register_pipeline,
     strip_parent_engine_args,
 )
-from vllm_omni.engine.arg_utils import internal_blacklist_keys
+from vllm_omni.engine.arg_utils import SHARED_FIELDS, internal_blacklist_keys
 
 
 class TestStageType:
@@ -330,6 +330,9 @@ class TestStageResolutionHelpers:
     """Tests for shared stage override / filtering helpers."""
 
     def test_build_stage_runtime_overrides_ignores_other_stage_and_internal_keys(self):
+        # Pass the same filter set the function uses by default
+        # (orchestrator-only fields plus SHARED_FIELDS so ``model`` is
+        # treated as not-per-stage-overridable).
         overrides = build_stage_runtime_overrides(
             0,
             {
@@ -339,7 +342,7 @@ def test_build_stage_runtime_overrides_ignores_other_stage_and_internal_keys(sel
                 "stage_0_model": "should_be_ignored",
                 "parallel_config": {"world_size": 2},
             },
-            internal_keys=internal_blacklist_keys(),
+            internal_keys=internal_blacklist_keys() | SHARED_FIELDS,
         )
 
         assert overrides["gpu_memory_utilization"] == 0.9
@@ -672,19 +675,27 @@ def test_parse_missing_async_chunk_defaults_false(self, tmp_path):
 
 
 class TestPipelineDiscovery:
-    """Tests for auto-discovery of pipelines from models/*/pipeline.py."""
+    """Tests for the central pipeline registry (``pipeline_registry._VLLM_OMNI_PIPELINES``)."""
 
-    def test_discover_populates_registry_with_known_models(self):
-        """``_discover_all_pipelines`` imports every pipeline.py so the
-        registry is populated with the built-in models after one call."""
-        from vllm_omni.config.stage_config import _discover_all_pipelines
-
-        _discover_all_pipelines()
-        # These models have a pipeline.py in-tree and must be registered.
+    def test_registry_has_known_models(self):
+        """Built-in pipelines are lazy-loaded from the central declaration
+        on first access; no eager import or discovery walk needed."""
+        # ``in`` triggers the lazy-map lookup without forcing a load.
         assert "qwen2_5_omni" in _PIPELINE_REGISTRY
         assert "qwen3_omni_moe" in _PIPELINE_REGISTRY
         assert "qwen3_tts" in _PIPELINE_REGISTRY
 
+    def test_registry_loads_pipeline_on_getitem(self):
+        """Looking up a registered model_type returns the matching PipelineConfig."""
+        pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
+        assert pipeline.model_type == "qwen3_omni_moe"
+        assert len(pipeline.stages) == 3  # thinker + talker + code2wav
+
+    def test_registry_returns_none_for_unknown(self):
+        """Unknown model_types aren't found; ``get()`` returns None."""
+        assert "definitely_not_a_real_model" not in _PIPELINE_REGISTRY
+        assert _PIPELINE_REGISTRY.get("definitely_not_a_real_model") is None
+
     def test_pipeline_config_supports_hf_architectures(self):
         """PipelineConfig accepts hf_architectures for HF-arch fallback
         (replaces the old _ARCHITECTURE_MODELS dict)."""
@@ -950,7 +961,10 @@ def test_ci_inherits_from_main(self):
         assert deploy.stages[0].gpu_memory_utilization == 0.9
         assert deploy.connectors is not None
         assert "connector_of_shared_memory" in deploy.connectors
-        assert deploy.async_chunk is True
+        # CI overlay explicitly sets async_chunk: False (see
+        # tests/utils.py::_CI_OVERLAYS and PR #2383 discussion). Overlay
+        # bool overrides base even when the base yaml has async_chunk: true.
+        assert deploy.async_chunk is False
 
     def test_ci_sampling_merge(self):
         from tests.utils import get_deploy_config_path
diff --git a/vllm_omni/config/pipeline_registry.py b/vllm_omni/config/pipeline_registry.py
new file mode 100644
index 00000000000..c07bc2610c3
--- /dev/null
+++ b/vllm_omni/config/pipeline_registry.py
@@ -0,0 +1,55 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Central declarative registry of all vllm-omni pipelines.
+
+Mirrors the pattern in ``vllm/model_executor/models/registry.py``: each entry
+is ``model_type -> (module_path, variable_name)``, and the module is imported
+lazily on first lookup (see ``_LazyPipelineRegistry`` in
+``vllm_omni/config/stage_config.py``). Keeping every pipeline declared in one
+file makes it easy to spot a missing registration, which was the original
+motivation in https://github.com/vllm-project/vllm-omni/issues/2887 (item 4).
+
+Per-model ``pipeline.py`` modules still define the ``PipelineConfig`` instance;
+they just no longer need to self-register via ``register_pipeline(...)``.
+
+Adding a new pipeline:
+  1. Define the ``PipelineConfig`` instance as a module-level variable in
+     ``vllm_omni/.../pipeline.py``.
+  2. Add one line to ``_OMNI_PIPELINES`` or ``_DIFFUSION_PIPELINES`` below.
+
+``register_pipeline(config)`` in ``stage_config`` is still supported for
+out-of-tree plugins and tests that create pipelines at runtime; those override
+the entries declared here.
+"""
+
+from __future__ import annotations
+
+# --- Multi-stage omni pipelines (LLM-centric; audio / video I/O) ---
+_OMNI_PIPELINES: dict[str, tuple[str, str]] = {
+    # model_type -> (module_path, variable_name)
+    "qwen2_5_omni": (
+        "vllm_omni.model_executor.models.qwen2_5_omni.pipeline",
+        "QWEN2_5_OMNI_PIPELINE",
+    ),
+    "qwen2_5_omni_thinker_only": (
+        "vllm_omni.model_executor.models.qwen2_5_omni.pipeline",
+        "QWEN2_5_OMNI_THINKER_ONLY_PIPELINE",
+    ),
+    "qwen3_omni_moe": (
+        "vllm_omni.model_executor.models.qwen3_omni.pipeline",
+        "QWEN3_OMNI_PIPELINE",
+    ),
+    "qwen3_tts": (
+        "vllm_omni.model_executor.models.qwen3_tts.pipeline",
+        "QWEN3_TTS_PIPELINE",
+    ),
+}
+
+# --- Single-stage diffusion pipelines (populated in PR 3/N) ---
+_DIFFUSION_PIPELINES: dict[str, tuple[str, str]] = {}
+
+# Union view used by ``_LazyPipelineRegistry``; don't mutate at runtime.
+_VLLM_OMNI_PIPELINES: dict[str, tuple[str, str]] = {
+    **_OMNI_PIPELINES,
+    **_DIFFUSION_PIPELINES,
+}
diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py
index fc14283630b..392a550be68 100644
--- a/vllm_omni/config/stage_config.py
+++ b/vllm_omni/config/stage_config.py
@@ -39,15 +39,18 @@ def build_stage_runtime_overrides(
 ) -> dict[str, Any]:
     """Build per-stage runtime overrides from global and ``stage__*`` kwargs.
 
-    ``internal_keys`` defaults to the set derived from ``OrchestratorArgs``
-    (via ``arg_utils.internal_blacklist_keys``) so that orchestrator
-    fields are never forwarded as per-stage engine args. Callers can pass an
-    explicit set for tests or specialized flows.
+    ``internal_keys`` defaults to the union of
+    ``arg_utils.internal_blacklist_keys()`` and ``arg_utils.SHARED_FIELDS``
+    so that neither orchestrator-only fields nor shared-pipeline fields
+    (``model`` / ``stage_configs_path`` / ``log_stats`` / ``stage_id``) leak
+    into a stage's per-stage runtime overrides — the orchestrator sets those
+    uniformly for every stage, they are not per-stage knobs. Callers can
+    pass an explicit set for tests or specialized flows.
     """
     if internal_keys is None:
-        from vllm_omni.engine.arg_utils import internal_blacklist_keys
+        from vllm_omni.engine.arg_utils import SHARED_FIELDS, internal_blacklist_keys
 
-        internal_keys = internal_blacklist_keys()
+        internal_keys = internal_blacklist_keys() | SHARED_FIELDS
 
     result: dict[str, Any] = {}
 
@@ -223,11 +226,179 @@ def validate(self) -> list[str]:
         return errors
 
 
-_PIPELINE_REGISTRY: dict[str, PipelineConfig] = {}
+class _LazyPipelineRegistry:
+    """Dict-like registry that lazy-loads pipelines from the central declaration.
+
+    In-tree pipelines are declared once in
+    ``vllm_omni/config/pipeline_registry.py`` as
+    ``model_type -> (module_path, variable_name)`` entries; the module is
+    imported only when the pipeline is first looked up. This mirrors the
+    pattern in ``vllm/model_executor/models/registry.py`` and addresses
+    https://github.com/vllm-project/vllm-omni/issues/2887 (item 4): having
+    every registration in one file makes a missing entry easy to spot.
+
+    Out-of-tree / dynamic registrations via ``register_pipeline()`` are stored
+    directly in ``_loaded`` and take precedence over the lazy-map entry with
+    the same ``model_type``.
+
+    The class exposes the subset of ``dict`` operations the rest of this
+    module relies on (``__contains__``, ``__getitem__``, ``__setitem__``,
+    ``__delitem__``, ``get``, ``keys``, ``values``, ``items``, ``__iter__``),
+    so existing call sites don't need to change.
+
+    ``__iter__``, ``values`` and ``items`` walk declared entries in
+    declaration order followed by dynamic-only registrations, so scans are
+    reproducible across runs. ``values`` and ``items`` skip a declared entry
+    that fails to load (``_load_lazy`` logs the error) rather than raising
+    part-way through, so one broken pipeline module cannot abort a scan.
+    """
+
+    def __init__(self) -> None:
+        # Resolved pipelines: imported built-ins plus dynamic registrations.
+        self._loaded: dict[str, PipelineConfig] = {}
+        # Populated lazily to avoid a circular import at module init time.
+        self._lazy_map: dict[str, tuple[str, str]] | None = None
+
+    def _get_lazy_map(self) -> dict[str, tuple[str, str]]:
+        """Return the central ``model_type -> (module_path, variable_name)`` map."""
+        if self._lazy_map is None:
+            from vllm_omni.config.pipeline_registry import _VLLM_OMNI_PIPELINES
+
+            self._lazy_map = _VLLM_OMNI_PIPELINES
+        return self._lazy_map
+
+    def _load_lazy(self, model_type: str) -> PipelineConfig | None:
+        """Import and cache the declared pipeline for ``model_type``.
+
+        Returns ``None`` when ``model_type`` is not declared, when its module
+        raises ``ImportError``, or when the declared variable is missing (the
+        last two are logged as errors). Validation issues are only logged as
+        a warning; the pipeline is still cached and returned.
+        """
+        entry = self._get_lazy_map().get(model_type)
+        if entry is None:
+            return None
+        module_path, var_name = entry
+        import importlib
+
+        try:
+            module = importlib.import_module(module_path)
+        except ImportError as exc:
+            logger.error(
+                "Failed to import pipeline module %r for %r: %s",
+                module_path,
+                model_type,
+                exc,
+            )
+            return None
+        pipeline = getattr(module, var_name, None)
+        if pipeline is None:
+            logger.error(
+                "Pipeline variable %r not found in module %r (registered for %r)",
+                var_name,
+                module_path,
+                model_type,
+            )
+            return None
+        errors = pipeline.validate()
+        if errors:
+            logger.warning("Pipeline %s has issues: %s", pipeline.model_type, errors)
+        self._loaded[model_type] = pipeline
+        return pipeline
+
+    def _ordered_keys(self) -> list[str]:
+        """Return every registered ``model_type`` once, in a stable order."""
+        # ``dict.fromkeys`` de-duplicates while keeping first-seen order:
+        # declared entries first, then dynamic-only registrations.
+        return list(dict.fromkeys((*self._get_lazy_map(), *self._loaded)))
+
+    def __contains__(self, model_type: str) -> bool:
+        """Report whether ``model_type`` is registered, without importing it."""
+        if model_type in self._loaded:
+            return True
+        return model_type in self._get_lazy_map()
+
+    def __getitem__(self, model_type: str) -> PipelineConfig:
+        """Return the pipeline for ``model_type``, importing it on first use.
+
+        Raises:
+            KeyError: ``model_type`` is unknown or its declared pipeline
+                could not be loaded.
+        """
+        if model_type in self._loaded:
+            return self._loaded[model_type]
+        pipeline = self._load_lazy(model_type)
+        if pipeline is None:
+            raise KeyError(model_type)
+        return pipeline
+
+    def get(self, model_type: str, default: PipelineConfig | None = None) -> PipelineConfig | None:
+        """Like ``__getitem__`` but return ``default`` instead of raising ``KeyError``."""
+        if model_type in self._loaded:
+            return self._loaded[model_type]
+        pipeline = self._load_lazy(model_type)
+        return pipeline if pipeline is not None else default
+
+    def __setitem__(self, model_type: str, pipeline: PipelineConfig) -> None:
+        """Store a dynamic registration, shadowing any declared entry."""
+        self._loaded[model_type] = pipeline
+
+    def __delitem__(self, model_type: str) -> None:
+        """Remove a dynamically-registered (or cached) pipeline.
+
+        Only the dynamic-cache side of the registry can be mutated; the
+        central declarative registry is immutable at runtime. Deleting a
+        model_type that is also declared centrally only drops the cached or
+        overriding object, so the next lookup re-imports the declared
+        pipeline. Calling ``del`` on a model_type that exists only in the
+        central registry (not yet loaded) raises ``KeyError``.
+        """
+        if model_type in self._loaded:
+            del self._loaded[model_type]
+            return
+        if model_type in self._get_lazy_map():
+            raise KeyError(
+                f"{model_type!r} is declared in the central pipeline_registry and "
+                "cannot be removed at runtime. Edit "
+                "vllm_omni/config/pipeline_registry.py to delete a built-in entry."
+            )
+        raise KeyError(model_type)
+
+    def keys(self) -> set[str]:
+        """Return all registered ``model_type`` names as an unordered set."""
+        return set(self._get_lazy_map().keys()) | set(self._loaded.keys())
+
+    def values(self):
+        """Yield every loadable pipeline; forces load of every lazy pipeline."""
+        for _, pipeline in self.items():
+            yield pipeline
+
+    def items(self):
+        """Yield ``(model_type, pipeline)`` pairs for every loadable pipeline."""
+        for key in self._ordered_keys():
+            pipeline = self.get(key)
+            # A declared entry that failed to load was already logged by
+            # ``_load_lazy``; skip it instead of raising mid-iteration.
+            if pipeline is not None:
+                yield key, pipeline
+
+    def __iter__(self):
+        """Iterate over registered ``model_type`` names in a stable order."""
+        return iter(self._ordered_keys())
+
+
+_PIPELINE_REGISTRY = _LazyPipelineRegistry()
 
 
 def register_pipeline(pipeline: PipelineConfig) -> None:
-    """Register a pipeline config (called at import time by pipeline.py modules)."""
+    """Register a pipeline config dynamically.
+
+    In-tree pipelines are declared in ``pipeline_registry._VLLM_OMNI_PIPELINES``
+    and loaded lazily; calling ``register_pipeline`` is only needed for
+    out-of-tree plugins or tests that build a ``PipelineConfig`` at runtime.
+    A dynamic registration overrides the central-registry entry with the same
+    ``model_type``.
+    """
     errors = pipeline.validate()
     if errors:
         logger.warning("Pipeline %s has issues: %s", pipeline.model_type, errors)
@@ -826,41 +997,17 @@ def validate_pipeline(self) -> list[str]:
         return errors
 
 
-def _discover_all_pipelines() -> None:
-    """Import every ``models//pipeline.py`` once to populate the registry.
-
-    Each pipeline.py is expected to call ``register_pipeline(PipelineConfig(...))``
-    at import time. This function walks the models directory and imports any
-    pipeline.py it finds — contributors only need to drop a new pipeline.py
-    in their model's directory for the factory to pick it up.
-
-    Idempotent: Python's module cache ensures subsequent calls are no-ops.
-    """
-    if not _MODELS_DIR.exists():
-        return
-    for subdir in sorted(_MODELS_DIR.iterdir()):
-        if not subdir.is_dir():
-            continue
-        if not (subdir / "pipeline.py").exists():
-            continue
-        module_path = f"vllm_omni.model_executor.models.{subdir.name}.pipeline"
-        try:
-            __import__(module_path)
-        except Exception as exc:
-            logger.debug("Skipping pipeline module %s: %s", module_path, exc)
-
-
 class StageConfigFactory:
     """Factory that loads pipeline YAML and merges CLI overrides.
 
     Handles both single-stage and multi-stage models.
 
-    Pipelines are auto-discovered from ``models//pipeline.py`` modules;
-    no hardcoded model-type → directory mapping is maintained here. Models
-    with generic HF ``model_type`` collisions (e.g. MiMo Audio reports
-    ``qwen2``) should declare ``hf_architectures=(...)`` on their
-    ``PipelineConfig`` so the factory can disambiguate via
-    ``hf_config.architectures``.
+    Pipelines are declared in ``vllm_omni/config/pipeline_registry.py`` and
+    loaded lazily via ``_PIPELINE_REGISTRY``; no hardcoded model-type →
+    directory mapping is maintained here. Models with generic HF
+    ``model_type`` collisions (e.g. MiMo Audio reports ``qwen2``) should
+    declare ``hf_architectures=(...)`` on their ``PipelineConfig`` so the
+    factory can disambiguate via ``hf_config.architectures``.
     """
 
     @classmethod
@@ -885,9 +1032,6 @@ def create_from_model(
 
         trust_remote_code = cli_overrides.get("trust_remote_code", True)
 
-        # Ensure every pipeline.py has been imported so the registry is populated.
-        _discover_all_pipelines()
-
         # --- New path: check pipeline registry by model_type first ---
         model_type, hf_config = cls._auto_detect_model_type(model, trust_remote_code=trust_remote_code)
         if model_type and model_type in _PIPELINE_REGISTRY:
diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py b/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py
index 2a9b247a1d6..b44d08eb32a 100644
--- a/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py
+++ b/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py
@@ -11,7 +11,6 @@
     PipelineConfig,
     StageExecutionType,
     StagePipelineConfig,
-    register_pipeline,
 )
 
 _PROC = "vllm_omni.model_executor.stage_input_processors.qwen2_5_omni"
@@ -57,8 +56,6 @@
     ),
 )
 
-register_pipeline(QWEN2_5_OMNI_PIPELINE)
-
 
 # Single-stage thinker-only variant for the abort test.
 QWEN2_5_OMNI_THINKER_ONLY_PIPELINE = PipelineConfig(
@@ -79,5 +76,3 @@
         ),
     ),
 )
-
-register_pipeline(QWEN2_5_OMNI_THINKER_ONLY_PIPELINE)
diff --git a/vllm_omni/model_executor/models/qwen3_omni/pipeline.py b/vllm_omni/model_executor/models/qwen3_omni/pipeline.py
index fcaa7ba0284..1c69ec79570 100644
--- a/vllm_omni/model_executor/models/qwen3_omni/pipeline.py
+++ b/vllm_omni/model_executor/models/qwen3_omni/pipeline.py
@@ -11,7 +11,6 @@
     PipelineConfig,
     StageExecutionType,
     StagePipelineConfig,
-    register_pipeline,
 )
 
 _PROC = "vllm_omni.model_executor.stage_input_processors.qwen3_omni"
@@ -62,5 +61,3 @@
         ),
     ),
 )
-
-register_pipeline(QWEN3_OMNI_PIPELINE)
diff --git a/vllm_omni/model_executor/models/qwen3_tts/pipeline.py b/vllm_omni/model_executor/models/qwen3_tts/pipeline.py
index 6c9ed447853..5051715ceac 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/pipeline.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/pipeline.py
@@ -9,7 +9,6 @@
     PipelineConfig,
     StageExecutionType,
     StagePipelineConfig,
-    register_pipeline,
 )
 
 _PROC = "vllm_omni.model_executor.stage_input_processors.qwen3_tts"
@@ -47,5 +46,3 @@
         ),
     ),
 )
-
-register_pipeline(QWEN3_TTS_PIPELINE)