diff --git a/.buildkite/test-nightly.yml b/.buildkite/test-nightly.yml index 6f206422a4b..7b9be0993dc 100644 --- a/.buildkite/test-nightly.yml +++ b/.buildkite/test-nightly.yml @@ -23,7 +23,7 @@ steps: - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT resources: limits: - nvidia.com/gpu: 2 + nvidia.com/gpu: 3 volumeMounts: - name: devshm mountPath: /dev/shm diff --git a/.buildkite/test-ready.yml b/.buildkite/test-ready.yml index 2d46c753a92..73163ae2f74 100644 --- a/.buildkite/test-ready.yml +++ b/.buildkite/test-ready.yml @@ -210,7 +210,7 @@ steps: - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT resources: limits: - nvidia.com/gpu: 2 + nvidia.com/gpu: 3 volumeMounts: - name: devshm mountPath: /dev/shm diff --git a/docs/configuration/pd_disaggregation.md b/docs/configuration/pd_disaggregation.md index 9196bdb0240..b65d27dfff2 100644 --- a/docs/configuration/pd_disaggregation.md +++ b/docs/configuration/pd_disaggregation.md @@ -3,162 +3,200 @@ PD disaggregation splits the Qwen3-Omni thinker into separate prefill and decode stages so prompt processing and token generation can run on different workers. -This is documented as a stage-config recipe instead of a bundled YAML because the -deployment-specific values usually change per environment: - -- GPU placement -- `tensor_parallel_size` -- connector backend and connector ports -- connector IPs or bootstrap addresses - -Start from the [default Qwen3-Omni stage config](gh-file:vllm_omni/deploy/qwen3_omni_moe.yaml) -and copy it to your own file, for example `qwen3_omni_pd.yaml`. Then apply the -changes below. +After the config refactor, PD is no longer launched from a separate legacy +`stage_configs/*.yaml` file. Instead, it is enabled from the deploy config via +the `pd_disaggregation` section in +[`vllm_omni/deploy/qwen3_omni_moe.yaml`](gh-file:vllm_omni/deploy/qwen3_omni_moe.yaml). + +## Current Config-Based Flow + +At runtime, the config system does the following when +`pd_disaggregation.enabled: true`: + +1. Load the normal 3-stage Qwen3-Omni pipeline + deploy config. +2. Dynamically split the thinker into: + - stage `0`: thinker prefill + - stage `1`: thinker decode +3. Shift downstream stages by one index: + - talker: `1 -> 2` + - code2wav: `2 -> 3` +4. Inject `is_prefill_only`, `is_decode_only`, and `kv_transfer_config` into the + resolved runtime stage configs. +5. Reuse the existing PD detection / routing logic in the engine. + +So the user-facing deploy file stays single-source, but the resolved runtime +config becomes a 4-stage PD pipeline. ## Requirements -- 3+ GPUs for a basic layout: prefill, decode, and talker+code2wav +- 3+ GPUs for the common layout: + - prefill on GPU `0` + - decode on GPU `1` + - talker + code2wav on GPU `2` - A KV connector supported by vLLM, such as `MooncakeConnector` - Matching `tensor_parallel_size` on the prefill and decode thinker stages -## 1. Split the thinker into prefill and decode stages +## How to Enable PD + +PD is enabled from the existing bundled deploy config: + +- `vllm_omni/deploy/qwen3_omni_moe.yaml` -Replace the original thinker stage with two stages: +No additional user-facing YAML is required. The intent of the config refactor +is to keep Qwen3-Omni on a single deploy config and switch PD on through the +`pd_disaggregation` section in that file. + +Edit `vllm_omni/deploy/qwen3_omni_moe.yaml` and enable / tune: ```yaml -stage_args: - - stage_id: 0 - stage_type: llm - is_prefill_only: true - runtime: - devices: "0" - engine_args: +pd_disaggregation: + enabled: true + async_chunk: false + target_stage_id: 0 + stages: + - role: prefill max_num_seqs: 16 - model_stage: thinker - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.9 - enforce_eager: true - trust_remote_code: true - engine_output_type: latent - distributed_executor_backend: "mp" - enable_prefix_caching: false - max_num_batched_tokens: 32768 - hf_config_name: thinker_config + devices: "0" tensor_parallel_size: 1 - kv_transfer_config: - kv_connector: "MooncakeConnector" - kv_role: "kv_producer" - kv_rank: 0 - kv_parallel_size: 2 - kv_connector_extra_config: - mooncake_bootstrap_port: 25201 - final_output: false - is_comprehension: true - default_sampling_params: - temperature: 0.4 - top_p: 0.9 - top_k: 1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.05 - - - stage_id: 1 - stage_type: llm - is_decode_only: true - runtime: - devices: "1" - engine_args: + engine_extras: + kv_transfer_config: + kv_connector: "MooncakeConnector" + kv_role: "kv_producer" + kv_rank: 0 + kv_parallel_size: 2 + kv_connector_extra_config: + mooncake_bootstrap_port: 25201 + - role: decode max_num_seqs: 64 - model_stage: thinker - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.9 - enforce_eager: true - trust_remote_code: true - engine_output_type: latent - distributed_executor_backend: "mp" - enable_prefix_caching: false - max_num_batched_tokens: 32768 - hf_config_name: thinker_config + devices: "1" tensor_parallel_size: 1 - kv_transfer_config: - kv_connector: "MooncakeConnector" - kv_role: "kv_consumer" - kv_rank: 1 - kv_parallel_size: 2 - kv_connector_extra_config: - mooncake_bootstrap_port: 25202 - engine_input_source: [0] - final_output: true - final_output_type: text - is_comprehension: true - default_sampling_params: - temperature: 0.4 - top_p: 0.9 - top_k: 1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.05 + engine_extras: + kv_transfer_config: + kv_connector: "MooncakeConnector" + kv_role: "kv_consumer" + kv_rank: 1 + kv_parallel_size: 2 + kv_connector_extra_config: + mooncake_bootstrap_port: 25202 + stage_overrides: + - stage_id: 1 + devices: "2" + - stage_id: 2 + devices: "2" ``` Notes: -- `is_prefill_only: true` marks the thinker stage that only saves KV. -- `is_decode_only: true` marks the thinker stage that resumes from remote KV. -- `kv_transfer_config` is required on both stages. -- The orchestrator forces the prefill stage to run with `max_tokens=1`, so the - prefill side only processes the prompt and exports KV. +- `target_stage_id: 0` means the original thinker is the stage being split. +- `async_chunk: false` matches the current PD path. +- The `pd_disaggregation.stage_overrides` block keeps the common 3-GPU layout: + - original talker (`stage_id: 1`) stays on GPU `2` + - original code2wav (`stage_id: 2`) stays on GPU `2` +- After PD expansion, these become runtime stage `2` and stage `3`. + +## Launching with Config-Based PD -## 2. Shift the downstream stages by one index +```bash +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \ + --deploy-config vllm_omni/deploy/qwen3_omni_moe.yaml +``` -After inserting the extra thinker stage, renumber the remaining stages: +If you edit the bundled deploy file in place, the explicit `--deploy-config` +flag is optional as long as the runtime resolves the default deploy config for +the model. -```yaml - - stage_id: 2 - runtime: - devices: "2" - engine_input_source: [1] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker +You can also enable PD from CLI without editing the YAML: - - stage_id: 3 - runtime: - devices: "2" - engine_args: - max_num_seqs: 1 - engine_input_source: [2] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav +```bash +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \ + --deploy-config vllm_omni/deploy/qwen3_omni_moe.yaml \ + --enable-pd-disaggregation ``` -Compared with the default Qwen3-Omni config: +`--enable-pd-disaggregation` overrides the deploy YAML's +`pd_disaggregation.enabled` value for that launch only. Because the CLI flag is +declared with `argparse.BooleanOptionalAction`, both forms are supported: -- the talker becomes stage `2` instead of stage `1` -- the code2wav stage becomes stage `3` instead of stage `2` -- the talker now reads from decode stage `1` +- `--enable-pd-disaggregation`: force PD on for this run +- `--no-enable-pd-disaggregation`: force PD off for this run -## 3. Add runtime edges for the four-stage pipeline +When the flag is omitted, the YAML value stays in effect. -```yaml -runtime: - enabled: true - edges: - - from: 0 - to: 1 - - from: 1 - to: 2 - - from: 2 - to: 3 +To tune the generated prefill/decode runtime stages from CLI, reuse +`--stage-overrides` after PD is enabled. In the resolved 4-stage runtime config, +stage `0` is prefill and stage `1` is decode: + +```bash +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \ + --deploy-config vllm_omni/deploy/qwen3_omni_moe.yaml \ + --enable-pd-disaggregation \ + --stage-overrides '{"0":{"max_num_seqs":8},"1":{"max_num_seqs":32}}' ``` -## 4. Launch with your custom config +## Tests + +At the moment, the PD-aware tests are these three files: + +- `tests/e2e/online_serving/test_qwen3_omni.py` +- `tests/e2e/online_serving/test_qwen3_omni_expansion.py` +- `tests/entrypoints/test_pd_disaggregation.py` + +### 1. Online serving E2E + +Both online-serving test files include the regular 2-GPU cases and the PD +3-GPU case in the same parametrized suite. The Qwen3-Omni coverage currently +uses these modes: + +- `default`: non-PD, 2-GPU layout +- `async_chunk`: non-PD async-chunk path, 2-GPU layout +- `pd_default`: PD disaggregation, 3-GPU layout + +No `VLLM_TEST_PD_MODE` environment variable is needed. The tests select the +desired mode directly from the parametrized config path, and all online-serving +Qwen3-Omni cases launch through the stage CLI harness (`use_stage_cli=True`). + +Run `test_qwen3_omni.py`: ```bash -vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \ - --stage-configs-path /path/to/qwen3_omni_pd.yaml +pytest -s -v tests/e2e/online_serving/test_qwen3_omni.py \ + -m "advanced_model" --run-level "advanced_model" +``` + +Run `test_qwen3_omni_expansion.py`: + +```bash +pytest -s -v tests/e2e/online_serving/test_qwen3_omni_expansion.py \ + -m "advanced_model" --run-level "advanced_model" +``` + +Run a single expansion case, for example: + +```bash +pytest -s -v tests/e2e/online_serving/test_qwen3_omni_expansion.py \ + -k "test_audio_in_video_002" \ + -m "advanced_model" --run-level "advanced_model" +``` + +### 2. PD unit / entrypoint coverage + +`test_pd_disaggregation.py` does not require the old PD YAML anymore. It builds +a temporary deploy overlay inside the test process only, enables +`pd_disaggregation`, then verifies that the merged runtime config becomes a valid +4-stage PD pipeline. This temporary file is a test helper, not a user-facing +config artifact. + +```bash +pytest tests/entrypoints/test_pd_disaggregation.py -q +``` + +### 3. DFX perf benchmark + +To run the Qwen3-Omni PD performance benchmark config added under +`tests/dfx/perf/tests/`, use: + +```bash +pytest tests/dfx/perf/scripts/run_benchmark.py \ + --test-config-file tests/dfx/perf/tests/test_qwen_omni_pd.json -s ``` ## Operational Notes diff --git a/tests/dfx/conftest.py b/tests/dfx/conftest.py index 0c59bed994f..8f06d28eb3b 100644 --- a/tests/dfx/conftest.py +++ b/tests/dfx/conftest.py @@ -11,7 +11,7 @@ from tests.dfx.reliability.helpers import list_remote_process_pids_by_pattern, post_chat_completions_raw from tests.helpers.runtime import OmniServerParams -from tests.helpers.stage_config import modify_stage_config +from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config from vllm_omni.platforms import current_omni_platform @@ -87,7 +87,16 @@ def create_unique_server_params( model = server_params["model"] stage_config_name = server_params.get("stage_config_name") if stage_config_name: - stage_config_path = str(stage_configs_dir / stage_config_name) + stage_config_path = None + raw_stage_config = Path(stage_config_name) + if raw_stage_config.is_absolute(): + stage_config_path = str(raw_stage_config) + else: + local_stage_config = stage_configs_dir / stage_config_name + if local_stage_config.exists(): + stage_config_path = str(local_stage_config) + else: + stage_config_path = get_deploy_config_path(stage_config_name) delete = server_params.get("delete", None) update = server_params.get("update", None) stage_config_path = modify_stage(stage_config_path, update, delete) diff --git a/tests/dfx/perf/scripts/run_benchmark.py b/tests/dfx/perf/scripts/run_benchmark.py index c1f3264c18e..3adc6a6c5c9 100644 --- a/tests/dfx/perf/scripts/run_benchmark.py +++ b/tests/dfx/perf/scripts/run_benchmark.py @@ -43,7 +43,7 @@ def _get_config_file_from_argv() -> str | None: if CONFIG_FILE_PATH is None: print( "No --test-config-file in argv, using default: tests/dfx/perf/tests/test_qwen_omni.json " - "(override with e.g. --test-config-file tests/dfx/perf/tests/test_tts.json)" + "(override with e.g. --test-config-file tests/dfx/perf/tests/test_qwen_omni_pd.json)" ) CONFIG_FILE_PATH = _DEFAULT_CONFIG_FILE diff --git a/tests/e2e/online_serving/test_qwen3_omni.py b/tests/e2e/online_serving/test_qwen3_omni.py index f1a3ce3535c..048f59c7037 100644 --- a/tests/e2e/online_serving/test_qwen3_omni.py +++ b/tests/e2e/online_serving/test_qwen3_omni.py @@ -6,11 +6,10 @@ import pytest -from tests.helpers.mark import hardware_test +from tests.helpers.mark import hardware_marks, hardware_test from tests.helpers.media import generate_synthetic_audio, generate_synthetic_image, generate_synthetic_video from tests.helpers.runtime import OmniServerParams, dummy_messages_from_mix_data from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config -from vllm_omni.platforms import current_omni_platform os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" @@ -18,8 +17,6 @@ models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"] -# Set VLLM_TEST_PD_MODE=1 to test PD disaggregation (follow-up — deploy overlay not yet migrated). -_USE_PD = os.environ.get("VLLM_TEST_PD_MODE", "0") == "1" _CI_DEPLOY = get_deploy_config_path("ci/qwen3_omni_moe.yaml") @@ -37,36 +34,75 @@ def get_chunk_config(config_path: str | None = None): return modify_stage_config(config_path, updates={"async_chunk": True}) -# Platform-specific overrides live inside the new deploy yaml's ``platforms:`` -# section, so a single ``_CI_DEPLOY`` path serves CUDA, ROCm, and XPU. -# TODO: re-add VLLM_TEST_PD_MODE branch once the PD-disaggregation deploy -# overlay has been migrated to the new schema (previously used the deleted -# ``qwen3_omni_moe_pd_ci.yaml`` stage-configs file). -if current_omni_platform.is_xpu(): - stage_configs = [_CI_DEPLOY] -else: # CUDA + ROCm MI325 share the same deploy config - stage_configs = [get_chunk_config()] +def get_pd_config(config_path: str | None = None): + """Load the qwen3_omni CI deploy yaml with PD disaggregation enabled.""" + if config_path is None: + config_path = _CI_DEPLOY + return modify_stage_config( + config_path, + updates={ + "pd_disaggregation.enabled": True, + "pd_disaggregation.async_chunk": False, + "stages": { + 1: {"devices": "2"}, + 2: {"devices": "2"}, + }, + }, + ) -# Create parameter combinations for model and stage config -test_params = [ - OmniServerParams(model=model, stage_config_path=stage_config) for model in models for stage_config in stage_configs -] -# For prefix caching, we enable it on the thinker (stage 0) via CLI override -# and enable prompt token details so that we can determine if any tokens were cached. -BLOCK_SIZE = 16 + +def get_prefix_caching_config(config_path: str): + """Create a stage config with prefix caching enabled on the thinker (stage 0).""" + path = modify_stage_config( + config_path, + updates={ + "stage_args": { + 0: {"engine_args.enable_prefix_caching": True}, + }, + }, + ) + return path + + +# Cover sync, async-chunk, and PD launch paths by default. +test_params = ( + [ + pytest.param( + OmniServerParams(model=model, stage_config_path=_CI_DEPLOY), + id="default", + marks=hardware_marks(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2), + ) + for model in models + ] + + [ + pytest.param( + OmniServerParams(model=model, stage_config_path=get_chunk_config()), + id="async_chunk", + marks=hardware_marks(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2), + ) + for model in models + ] + + [ + pytest.param( + OmniServerParams(model=model, stage_config_path=get_pd_config()), + id="pd_default", + marks=hardware_marks(res={"cuda": "H100", "rocm": "MI325"}, num_cards=3), + ) + for model in models + ] +) +prefix_caching_stage_configs = [get_prefix_caching_config(_CI_DEPLOY)] + +# For prefix caching, we need to enable prompt token details so that we +# can determine if any tokens were cached. prefix_test_params = [ OmniServerParams( model=model, - stage_config_path=_CI_DEPLOY, - server_args=[ - "--block-size", - str(BLOCK_SIZE), - "--stage-overrides", - '{"0": {"enable_prefix_caching": true}}', - "--enable-prompt-tokens-details", - ], + stage_config_path=stage_config, + server_args=["--enable-prompt-tokens-details"], # Enable prompt tokens details to get cached_tokens ) for model in models + for stage_config in prefix_caching_stage_configs ] @@ -103,8 +139,7 @@ def get_max_batch_size(size_type="few"): @pytest.mark.advanced_model @pytest.mark.core_model @pytest.mark.omni -@pytest.mark.skipif(_USE_PD, reason="Temporarily skip PD mode in this test module.") -@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=3 if _USE_PD else 2) +@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2) @pytest.mark.parametrize("omni_server", test_params, indirect=True) def test_mix_to_text_audio_001(omni_server, openai_client) -> None: """ @@ -143,8 +178,7 @@ def test_mix_to_text_audio_001(omni_server, openai_client) -> None: @pytest.mark.advanced_model @pytest.mark.core_model @pytest.mark.omni -@pytest.mark.skipif(_USE_PD, reason="Temporarily skip PD mode in this test module.") -@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=3 if _USE_PD else 2) +@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2) @pytest.mark.parametrize("omni_server", test_params, indirect=True) def test_text_to_text_001(omni_server, openai_client) -> None: """ @@ -172,17 +206,15 @@ def test_text_to_text_001(omni_server, openai_client) -> None: @pytest.mark.omni @hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2) @pytest.mark.parametrize("omni_server", prefix_test_params, indirect=True) -def test_thinker_prefix_caching(omni_server, openai_client, run_level) -> None: +def test_thinker_prefix_caching(omni_server, openai_client) -> None: """ Test thinker prefix caching by sending identical requests with an image (i.e., a large shared prefix) and verifying that the second request uses cached tokens & produces the same output with greedy decoding. - NOTE: The reason that we check against logprobs instead of direct text here is that - the outputs may still diverge a bit even though we set the seed and temperature. - This is mostly because the GEMM algorithm may vary based on the input tensors dims. - Because of this, we don't check the logprobs if it's a dummy load, since in that case - the top logprobs will all be very close. + NOTE: The seed for this test is used as a regression test for the issue linked below; + https://github.com/vllm-project/vllm-omni/issues/2833; without passing the sampling + params, this test will fail with the current default stage configs. """ seed = 10 img_res = generate_synthetic_image(224, 224, seed=seed) @@ -193,41 +225,19 @@ def test_thinker_prefix_caching(omni_server, openai_client, run_level) -> None: content_text=get_prompt("text_image"), ) - top_k = 10 - sampling_params = {"seed": seed, "temperature": 0, "max_tokens": 8, "logprobs": top_k} request_config = { "model": omni_server.model, "messages": messages, "stream": False, "modalities": ["text"], - "logprobs": True, - "top_logprobs": top_k, - "sampling_params_list": [sampling_params] * 3, + "sampling_params_list": [{"seed": seed, "temperature": 0, "max_tokens": 16}] * 3, } - uncached_response = openai_client.send_omni_request(request_config, request_num=1)[0] - cached_response = openai_client.send_omni_request(request_config, request_num=1)[0] - - # Ensure that we have a prefix cache hit on the second request and that only the last - # partial block is uncached (since currently we don't cache partial blocks). - num_cached_tokens = cached_response.cached_tokens - num_prompt_tokens = cached_response.prompt_tokens - assert num_cached_tokens is not None and num_prompt_tokens is not None - num_uncached_tokens = num_prompt_tokens % BLOCK_SIZE - assert num_cached_tokens > 0 - assert num_cached_tokens % BLOCK_SIZE == 0 - assert (num_cached_tokens + num_uncached_tokens) == num_prompt_tokens - - # Ensure that we have logprobs and tokens were generated for both requests - assert uncached_response.logprobs is not None - assert cached_response.logprobs is not None - n_tokens = min(len(uncached_response.logprobs), len(cached_response.logprobs)) - assert n_tokens > 0 - - if run_level == "advanced_model": - # For each token index where both responses have an output, ensure that the greedy token - # predicted in the uncached case is in the top k logprobs for the cached case - for idx in range(n_tokens): - greedy_token = uncached_response.logprobs[idx].token - cached_top_k = {lp.token for lp in cached_response.logprobs[idx].top_logprobs} - assert greedy_token in cached_top_k + response_1 = openai_client.send_omni_request(request_config, request_num=1)[0] + response_2 = openai_client.send_omni_request(request_config, request_num=1)[0] + + # We should cache the vast majority of the prompt (image + up to last full block), + # and set seed + temperature, so the second request should give an identical + # response for the generated input image, even if we use dummy weights + assert response_2.cached_tokens is not None and response_2.cached_tokens > 0 + assert response_1.text_content == response_2.text_content diff --git a/tests/e2e/online_serving/test_qwen3_omni_expansion.py b/tests/e2e/online_serving/test_qwen3_omni_expansion.py index 9601057d44c..77cdecd8d02 100644 --- a/tests/e2e/online_serving/test_qwen3_omni_expansion.py +++ b/tests/e2e/online_serving/test_qwen3_omni_expansion.py @@ -8,7 +8,7 @@ import pytest -from tests.helpers.mark import hardware_test +from tests.helpers.mark import hardware_marks, hardware_test from tests.helpers.media import generate_synthetic_audio, generate_synthetic_image, generate_synthetic_video from tests.helpers.runtime import OmniServerParams, dummy_messages_from_mix_data from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config @@ -34,7 +34,7 @@ LONG_AUDIO_DURATION_SEC = 120 -def get_batch_token_config(default_path): +def get_batch_token_config(default_path, *, stage_id: int = 1): """Override stage 1's max_num_batched_tokens to exercise small-batch paths. Uses the new flat-stage schema (``stages..``); the legacy @@ -44,7 +44,7 @@ def get_batch_token_config(default_path): return modify_stage_config( default_path, updates={ - "stages": {1: {"max_num_batched_tokens": 64}}, + "stages": {stage_id: {"max_num_batched_tokens": 64}}, }, ) @@ -61,39 +61,62 @@ def get_async_chunk_config(default_path): return modify_stage_config( default_path, updates={ + "async_chunk": True, "stages": {0: {"default_sampling_params.max_tokens": 2048}}, }, ) # CI deploy YAML (single file; xpu deltas applied via ``platforms:`` section). -# The overlay explicitly sets ``async_chunk: False``, so ``default`` tests the -# sync path and ``async_chunk`` tests the streaming path with a longer thinker -# output — two distinct scenarios, kept as separate parametrizations. +# Keep the three Qwen3-Omni launch modes independent: +# * default -> non-PD CI overlay +# * async_chunk -> non-PD overlay with async_chunk enabled +# * pd_default -> PD-specific CI overlay (3-GPU layout) default_path = get_deploy_config_path("ci/qwen3_omni_moe.yaml") +async_chunk_path = get_async_chunk_config(default_path) +pd_path = get_deploy_config_path("ci/qwen3_omni_moe_pd.yaml") test_params = [ pytest.param( OmniServerParams( - model=model, stage_config_path=default_path, use_stage_cli=True, server_args=["--no-async-chunk"] + model=model, + stage_config_path=default_path, + use_stage_cli=True, + server_args=["--no-async-chunk"], ), id="default", + marks=hardware_marks(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2), ), pytest.param( OmniServerParams( model=model, - stage_config_path=get_async_chunk_config(default_path), + stage_config_path=async_chunk_path, use_stage_cli=True, - server_args=["--async-chunk"], ), id="async_chunk", + marks=hardware_marks(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2), + ), + pytest.param( + OmniServerParams( + model=model, + stage_config_path=pd_path, + use_stage_cli=True, + server_args=["--no-async-chunk"], + ), + id="pd_default", + marks=hardware_marks(res={"cuda": "H100", "rocm": "MI325"}, num_cards=3), ), ] test_token_params = [ pytest.param( - OmniServerParams(model=model, stage_config_path=get_batch_token_config(default_path), use_stage_cli=True), + OmniServerParams( + model=model, + stage_config_path=get_batch_token_config(default_path, stage_id=1), + use_stage_cli=True, + ), id="batch_token_64", + marks=hardware_marks(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2), ) ] diff --git a/tests/engine/test_orchestrator.py b/tests/engine/test_orchestrator.py index c33e23b05fc..c6060498a13 100644 --- a/tests/engine/test_orchestrator.py +++ b/tests/engine/test_orchestrator.py @@ -609,3 +609,75 @@ async def test_run_abort(orchestrator_factory) -> None: assert "req-abort" not in orchestrator_fixture.orchestrator.request_states finally: await _shutdown_orchestrator(orchestrator_fixture) + + +def test_build_pd_decode_params_uses_transfer_id_bootstrap_and_engine_id() -> None: + request_queue = janus.Queue() + output_queue = janus.Queue() + rpc_queue = janus.Queue() + orchestrator = Orchestrator( + request_async_queue=request_queue.async_q, + output_async_queue=output_queue.async_q, + rpc_async_queue=rpc_queue.async_q, + stage_clients=[FakeStageClient(), FakeStageClient()], + output_processors=[FakeOutputProcessor(), FakeOutputProcessor()], + stage_vllm_configs=[ + SimpleNamespace(model_config=SimpleNamespace(max_model_len=64)), + SimpleNamespace(model_config=SimpleNamespace(max_model_len=64)), + ], + pd_config={ + "pd_pair": (0, 1), + "bootstrap_addr": "127.0.0.1:25201", + "prefill_engine_id": "prefill-engine-0", + }, + ) + sp = _sampling_params() + + try: + result = orchestrator._build_pd_decode_params("req-pd", sp) + kv_params = result.extra_args["kv_transfer_params"] + + assert kv_params["transfer_id"] == "xfer-req-pd" + assert kv_params["remote_bootstrap_addr"] == "127.0.0.1:25201" + assert kv_params["remote_engine_id"] == "prefill-engine-0" + assert kv_params["do_remote_prefill"] is True + assert kv_params["do_remote_decode"] is False + finally: + for q in (request_queue, output_queue, rpc_queue): + q.close() + + +def test_build_pd_decode_params_preserves_prefill_overlay_fields() -> None: + request_queue = janus.Queue() + output_queue = janus.Queue() + rpc_queue = janus.Queue() + orchestrator = Orchestrator( + request_async_queue=request_queue.async_q, + output_async_queue=output_queue.async_q, + rpc_async_queue=rpc_queue.async_q, + stage_clients=[FakeStageClient(), FakeStageClient()], + output_processors=[FakeOutputProcessor(), FakeOutputProcessor()], + stage_vllm_configs=[ + SimpleNamespace(model_config=SimpleNamespace(max_model_len=64)), + SimpleNamespace(model_config=SimpleNamespace(max_model_len=64)), + ], + pd_config={ + "pd_pair": (0, 1), + "bootstrap_addr": "127.0.0.1:25201", + "prefill_engine_id": "prefill-engine-0", + }, + ) + orchestrator._pd_kv_params["req-pd"] = {"remote_request_id": "legacy-prefill-id"} + sp = _sampling_params() + + try: + result = orchestrator._build_pd_decode_params("req-pd", sp) + kv_params = result.extra_args["kv_transfer_params"] + + assert kv_params["transfer_id"] == "xfer-req-pd" + assert kv_params["remote_bootstrap_addr"] == "127.0.0.1:25201" + assert kv_params["remote_engine_id"] == "prefill-engine-0" + assert kv_params["remote_request_id"] == "legacy-prefill-id" + finally: + for q in (request_queue, output_queue, rpc_queue): + q.close() diff --git a/tests/entrypoints/test_pd_disaggregation.py b/tests/entrypoints/test_pd_disaggregation.py index 5ffabfbf2af..3704e109aad 100644 --- a/tests/entrypoints/test_pd_disaggregation.py +++ b/tests/entrypoints/test_pd_disaggregation.py @@ -22,8 +22,6 @@ from vllm_omni.entrypoints.pd_utils import PDDisaggregationMixin -pytestmark = pytest.mark.skip(reason="Temporarily skip PD entrypoint tests while PD config is being removed.") - # Suppress noisy DeprecationWarnings from optional Swig bindings imported by vLLM dependencies. warnings.filterwarnings( "ignore", @@ -1081,41 +1079,52 @@ def test_pop_uses_fallback_when_no_stored(self, monkeypatch): class TestPDYAMLConfig: - def test_pd_yaml_loads(self): - """The PD separation YAML config should load without errors.""" + def test_pd_yaml_loads(self, tmp_path): + """PD deploy config should merge into a 4-stage runtime config.""" import os - yaml_path = os.path.join( - os.path.dirname(__file__), - "../../vllm_omni/model_executor/stage_configs/qwen3_omni_moe_pd_separation.yaml", + import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401 + from vllm_omni.config.stage_config import _PIPELINE_REGISTRY, load_deploy_config, merge_pipeline_deploy + + base_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../vllm_omni/deploy/qwen3_omni_moe.yaml") ) - yaml_path = os.path.abspath(yaml_path) - if not os.path.exists(yaml_path): - pytest.skip("PD separation YAML not found") + if not os.path.exists(base_path): + pytest.skip("Qwen3-Omni deploy config not found") - from omegaconf import OmegaConf + overlay = tmp_path / "qwen3_omni_pd_overlay.yaml" + overlay.write_text( + f"base_config: {base_path}\npd_disaggregation:\n enabled: true\n async_chunk: false\n", + encoding="utf-8", + ) - cfg = OmegaConf.load(yaml_path) - stages = cfg.stage_args + deploy = load_deploy_config(overlay) + stages = merge_pipeline_deploy(_PIPELINE_REGISTRY["qwen3_omni_moe"], deploy) assert len(stages) == 4 # Prefill stage - assert stages[0].is_prefill_only is True + assert stages[0].yaml_extras["is_prefill_only"] is True assert stages[0].final_output is False - assert stages[0].is_comprehension is True + assert stages[0].is_comprehension is False # Decode stage - assert stages[1].is_decode_only is True + assert stages[1].yaml_extras["is_decode_only"] is True assert stages[1].final_output is True assert stages[1].final_output_type == "text" assert stages[1].is_comprehension is True - assert 0 in stages[1].engine_input_source + assert 0 in stages[1].input_sources + + # Common 3-GPU PD layout + assert stages[0].yaml_runtime["devices"] == "0" + assert stages[1].yaml_runtime["devices"] == "1" + assert stages[2].yaml_runtime["devices"] == "2" + assert stages[3].yaml_runtime["devices"] == "2" # KV transfer configs - assert stages[0].engine_args.kv_transfer_config.kv_role == "kv_producer" - assert stages[1].engine_args.kv_transfer_config.kv_role == "kv_consumer" - assert stages[0].engine_args.kv_transfer_config.kv_connector == "MooncakeConnector" - assert stages[1].engine_args.kv_transfer_config.kv_connector == "MooncakeConnector" + assert stages[0].yaml_engine_args["kv_transfer_config"]["kv_role"] == "kv_producer" + assert stages[1].yaml_engine_args["kv_transfer_config"]["kv_role"] == "kv_consumer" + assert stages[0].yaml_engine_args["kv_transfer_config"]["kv_connector"] == "MooncakeConnector" + assert stages[1].yaml_engine_args["kv_transfer_config"]["kv_connector"] == "MooncakeConnector" class TestPrefillStopNeutralization: diff --git a/tests/helpers/runtime.py b/tests/helpers/runtime.py index 79ca9ee5cf5..a6e13147cb2 100644 --- a/tests/helpers/runtime.py +++ b/tests/helpers/runtime.py @@ -41,6 +41,7 @@ decode_b64_image, ) from vllm_omni.config.stage_config import resolve_deploy_yaml +from vllm_omni.entrypoints.utils import detect_explicit_cli_keys, load_and_resolve_stage_configs from vllm_omni.platforms import current_omni_platform logger = init_logger(__name__) @@ -381,8 +382,21 @@ def __init__( f"{yaml.safe_dump(resolved_cfg, sort_keys=False, default_flow_style=False)}", flush=True, ) - self.stage_runtime_devices = self._load_stage_runtime_devices(resolved_cfg) - self.stage_ids = stage_ids or self._load_stage_ids(resolved_cfg) + runtime_stage_configs = self._load_runtime_stage_configs(model, stage_config_path, self.serve_args) + runtime_layout = [ + { + "stage_id": self._stage_id(stage_cfg), + "devices": self._runtime_devices(stage_cfg), + } + for stage_cfg in runtime_stage_configs + ] + print( + f"[OmniServerStageCli] Runtime stage layout from {stage_config_path}:\n" + f"{yaml.safe_dump(runtime_layout, sort_keys=False, default_flow_style=False)}", + flush=True, + ) + self.stage_runtime_devices = self._load_stage_runtime_devices(runtime_stage_configs) + self.stage_ids = stage_ids or self._load_stage_ids(runtime_stage_configs) if 0 not in self.stage_ids: raise ValueError(f"Stage CLI test requires stage_id=0 in config: {stage_config_path}") self.stage_procs: dict[int, subprocess.Popen] = {} @@ -395,24 +409,82 @@ def _stage_entries(cfg: dict) -> list[dict]: return cfg.get("stage_args") or cfg.get("stages") or [] @staticmethod - def _load_stage_ids(resolved_config: dict) -> list[int]: - stage_ids = [ - stage["stage_id"] for stage in OmniServerStageCli._stage_entries(resolved_config) if "stage_id" in stage - ] + def _stage_id(stage: Any) -> int | None: + if isinstance(stage, dict): + stage_id = stage.get("stage_id") + else: + stage_id = getattr(stage, "stage_id", None) + return int(stage_id) if stage_id is not None else None + + @staticmethod + def _runtime_devices(stage: Any) -> str | None: + if isinstance(stage, dict): + devices = stage.get("devices") or stage.get("runtime", {}).get("devices") + return str(devices) if devices else None + + runtime = getattr(stage, "runtime", None) + if runtime is not None: + if hasattr(runtime, "get"): + devices = runtime.get("devices") + else: + devices = getattr(runtime, "devices", None) + if devices: + return str(devices) + + runtime_overrides = getattr(stage, "runtime_overrides", None) or {} + yaml_runtime = getattr(stage, "yaml_runtime", None) or {} + devices = runtime_overrides.get("devices") or yaml_runtime.get("devices") + return str(devices) if devices else None + + @classmethod + def _load_runtime_stage_configs(cls, model: str, stage_config_path: str, serve_args: list[str]) -> list[Any]: + cli_argv = ["serve", model, "--omni", *serve_args] + cli_kwargs: dict[str, Any] = {"_cli_explicit_keys": detect_explicit_cli_keys(cli_argv)} + try: + from vllm.utils.argparse_utils import FlexibleArgumentParser + + from vllm_omni.entrypoints.cli.serve import OmniServeCommand + + root = FlexibleArgumentParser(add_help=False) + subparsers = root.add_subparsers(dest="subcommand") + cmd = OmniServeCommand() + serve_parser = cmd.subparser_init(subparsers) + args, _ = root.parse_known_args(cli_argv) + cli_kwargs = vars(args).copy() + cli_kwargs.pop("subcommand", None) + cli_kwargs["_cli_explicit_keys"] = detect_explicit_cli_keys(cli_argv, serve_parser) + except Exception: + # Parser construction is best-effort in tests; the heuristic + # explicit-key path still lets deploy YAMLs drive runtime layout. + pass + + _, stage_configs = load_and_resolve_stage_configs( + model=model, + stage_configs_path=stage_config_path, + kwargs=cli_kwargs, + ) + return stage_configs + + @classmethod + def _load_stage_ids(cls, stage_source: dict | list[Any]) -> list[int]: + if isinstance(stage_source, dict): + stage_ids = [cls._stage_id(stage) for stage in cls._stage_entries(stage_source)] + else: + stage_ids = [cls._stage_id(stage) for stage in stage_source] + stage_ids = [stage_id for stage_id in stage_ids if stage_id is not None] if not stage_ids: raise ValueError("No stage IDs found in resolved config") return stage_ids - @staticmethod - def _load_stage_runtime_devices(resolved_config: dict) -> dict[int, str]: + @classmethod + def _load_stage_runtime_devices(cls, stage_source: dict | list[Any]) -> dict[int, str]: runtime_devices: dict[int, str] = {} - for stage in OmniServerStageCli._stage_entries(resolved_config): - stage_id = stage.get("stage_id") - # New schema: stage.devices is flat at stage level. - # Legacy schema: stage.runtime.devices is nested. - devices = stage.get("devices") or stage.get("runtime", {}).get("devices") + stage_entries = cls._stage_entries(stage_source) if isinstance(stage_source, dict) else stage_source + for stage in stage_entries: + stage_id = cls._stage_id(stage) + devices = cls._runtime_devices(stage) if stage_id is not None and devices: - runtime_devices[int(stage_id)] = str(devices) + runtime_devices[stage_id] = devices return runtime_devices @classmethod diff --git a/tests/helpers/stage_config.py b/tests/helpers/stage_config.py index 66a58378d32..ccb690cb73f 100644 --- a/tests/helpers/stage_config.py +++ b/tests/helpers/stage_config.py @@ -328,6 +328,9 @@ def delete_by_path(config_dict: dict, path: str) -> None: "qwen3_omni_moe": { "base_config": "qwen3_omni_moe.yaml", "async_chunk": False, + "pd_disaggregation": { + "enabled": False, + }, "stages": [ { "stage_id": 0, @@ -484,6 +487,161 @@ def delete_by_path(config_dict: dict, path: str) -> None: }, }, }, + "qwen3_omni_moe_pd": { + "base_config": "qwen3_omni_moe.yaml", + "async_chunk": False, + "stages": [ + { + "stage_id": 0, + "max_num_seqs": 5, + "max_model_len": 32768, + "mm_processor_cache_gb": 0, + "load_format": "dummy", + "default_sampling_params": {"max_tokens": 150, "ignore_eos": False}, + }, + { + "stage_id": 1, + "gpu_memory_utilization": 0.5, + "max_num_seqs": 5, + "max_model_len": 32768, + "load_format": "dummy", + "default_sampling_params": {"max_tokens": 1000}, + }, + { + "stage_id": 2, + "max_num_seqs": 5, + "max_num_batched_tokens": 100000, + "load_format": "dummy", + "default_sampling_params": {"max_tokens": 2000}, + }, + ], + "pd_disaggregation": { + "enabled": True, + "target_stage_id": 0, + "async_chunk": False, + "stages": [ + { + "role": "prefill", + "max_num_seqs": 5, + "devices": "0", + "tensor_parallel_size": 1, + "default_sampling_params": { + "temperature": 0.4, + "top_p": 0.9, + "top_k": 1, + "max_tokens": 150, + "seed": 42, + "repetition_penalty": 1.05, + "ignore_eos": False, + }, + "engine_extras": { + "kv_transfer_config": { + "kv_connector": "MooncakeConnector", + "kv_role": "kv_producer", + "kv_rank": 0, + "kv_parallel_size": 2, + "kv_connector_extra_config": { + "mooncake_bootstrap_port": 25201, + }, + }, + }, + }, + { + "role": "decode", + "max_num_seqs": 5, + "devices": "1", + "tensor_parallel_size": 1, + "default_sampling_params": { + "temperature": 0.4, + "top_p": 0.9, + "top_k": 1, + "max_tokens": 1024, + "seed": 42, + "repetition_penalty": 1.05, + "ignore_eos": False, + }, + "engine_extras": { + "kv_transfer_config": { + "kv_connector": "MooncakeConnector", + "kv_role": "kv_consumer", + "kv_rank": 1, + "kv_parallel_size": 2, + "kv_connector_extra_config": { + "mooncake_bootstrap_port": 25202, + }, + }, + }, + }, + ], + "stage_overrides": [ + { + "stage_id": 1, + "devices": "2", + }, + { + "stage_id": 2, + "devices": "2", + }, + ], + }, + "platforms": { + "rocm": { + "stages": [ + {"stage_id": 0, "max_num_seqs": 1, "default_sampling_params": {"max_tokens": 100}}, + { + "stage_id": 1, + "max_num_seqs": 1, + "enforce_eager": True, + "default_sampling_params": {"max_tokens": 100}, + }, + { + "stage_id": 2, + "max_num_seqs": 1, + "max_num_batched_tokens": 1000000, + "default_sampling_params": {"max_tokens": 200}, + }, + ], + }, + "xpu": { + "stages": [ + { + "stage_id": 0, + "gpu_memory_utilization": 0.85, + "max_num_seqs": 1, + "tensor_parallel_size": 4, + "enforce_eager": True, + "max_num_batched_tokens": 4096, + "max_model_len": 4096, + "max_cudagraph_capture_size": 0, + "skip_mm_profiling": True, + "devices": "0,1,2,3", + "default_sampling_params": {"max_tokens": 100, "ignore_eos": False}, + }, + { + "stage_id": 1, + "gpu_memory_utilization": 0.6, + "max_num_seqs": 1, + "enforce_eager": True, + "max_num_batched_tokens": 4096, + "max_model_len": 4096, + "max_cudagraph_capture_size": 0, + "skip_mm_profiling": True, + "devices": "4", + }, + { + "stage_id": 2, + "gpu_memory_utilization": 0.3, + "max_num_seqs": 1, + "max_num_batched_tokens": 100000, + "max_cudagraph_capture_size": 0, + "skip_mm_profiling": True, + "devices": "5", + "default_sampling_params": {"max_tokens": 2000}, + }, + ], + }, + }, + }, "ming_flash_omni": { "base_config": "ming_flash_omni.yaml", "stages": [ diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py index 7abe8fc8693..47e64d9c084 100644 --- a/tests/test_config_factory.py +++ b/tests/test_config_factory.py @@ -842,6 +842,98 @@ def test_merge_pipeline_deploy(self): assert s0.yaml_engine_args["engine_output_type"] == "latent" assert s0.yaml_extras["default_sampling_params"]["detokenize"] is True + def test_merge_pipeline_deploy_with_pd_disaggregation(self, tmp_path): + from pathlib import Path + + import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401 + from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy + + pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"] + base = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml" + if not base.exists(): + pytest.skip("Deploy config not found") + + overlay = tmp_path / "qwen3_omni_pd.yaml" + overlay.write_text( + f"base_config: {base}\n" + "pd_disaggregation:\n" + " enabled: true\n" + " target_stage_id: 0\n" + " async_chunk: false\n" + " stages:\n" + " - role: prefill\n" + " max_num_seqs: 16\n" + ' devices: "0"\n' + " engine_extras:\n" + " kv_transfer_config:\n" + " kv_connector: MooncakeConnector\n" + " kv_role: kv_producer\n" + " kv_rank: 0\n" + " kv_parallel_size: 2\n" + " - role: decode\n" + " max_num_seqs: 64\n" + ' devices: "1"\n' + " engine_extras:\n" + " kv_transfer_config:\n" + " kv_connector: MooncakeConnector\n" + " kv_role: kv_consumer\n" + " kv_rank: 1\n" + " kv_parallel_size: 2\n", + encoding="utf-8", + ) + + deploy = load_deploy_config(overlay) + stages = merge_pipeline_deploy(pipeline, deploy) + + assert len(stages) == 4 + assert stages[0].yaml_extras["is_prefill_only"] is True + assert stages[1].yaml_extras["is_decode_only"] is True + assert stages[0].is_comprehension is False + assert stages[1].is_comprehension is True + assert stages[1].input_sources == [0] + assert stages[2].input_sources == [1] + assert stages[3].input_sources == [2] + assert stages[0].yaml_runtime["devices"] == "0" + assert stages[1].yaml_runtime["devices"] == "1" + assert stages[2].yaml_runtime["devices"] == "2" + assert stages[3].yaml_runtime["devices"] == "2" + assert stages[0].yaml_engine_args.get("async_chunk") is not True + assert stages[1].yaml_engine_args.get("custom_process_next_stage_input_func") is None + assert stages[0].yaml_engine_args["kv_transfer_config"]["kv_role"] == "kv_producer" + assert stages[1].yaml_engine_args["kv_transfer_config"]["kv_role"] == "kv_consumer" + assert stages[2].yaml_extras["input_connectors"] == {"from_stage_1": "connector_of_shared_memory"} + assert stages[3].yaml_extras["input_connectors"] == {"from_stage_2": "connector_of_shared_memory"} + + def test_merge_pipeline_deploy_with_pd_disaggregation_cli_flag(self): + from pathlib import Path + + import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401 + from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy + + pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"] + deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml" + if not deploy_path.exists(): + pytest.skip("Deploy config not found") + + deploy = load_deploy_config(deploy_path) + stages = merge_pipeline_deploy( + pipeline, + deploy, + cli_overrides={ + "enable_pd_disaggregation": True, + "stage_0_max_num_seqs": 8, + "stage_1_max_num_seqs": 32, + }, + ) + + assert len(stages) == 4 + assert stages[0].yaml_extras["is_prefill_only"] is True + assert stages[1].yaml_extras["is_decode_only"] is True + assert stages[2].yaml_runtime["devices"] == "2" + assert stages[3].yaml_runtime["devices"] == "2" + assert stages[0].runtime_overrides["max_num_seqs"] == 8 + assert stages[1].runtime_overrides["max_num_seqs"] == 32 + class TestQwen3OmniPipeline: def test_registered(self): diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py index d4e33667723..dbfe68a3bfc 100644 --- a/vllm_omni/config/stage_config.py +++ b/vllm_omni/config/stage_config.py @@ -434,6 +434,9 @@ class DeployConfig: platforms: dict[str, Any] | None = None # Overrides the auto-detected pipeline registry key for structural variants. pipeline: str | None = None + # PD (Prefill-Decode) disaggregation configuration. + # When enabled, dynamically splits the target stage into prefill and decode stages. + pd_disaggregation: dict[str, Any] | None = None # === Pipeline-wide engine settings (applied uniformly to every stage) === trust_remote_code: bool = True @@ -542,6 +545,39 @@ def _merge_platforms( return merged +def _merge_pd_stage_role_lists( + base_stages: list[dict[str, Any]] | None, + overlay_stages: list[dict[str, Any]] | None, +) -> list[dict[str, Any]]: + """Merge PD ``stages:`` by ``role`` (prefill/decode).""" + by_role: dict[str, dict[str, Any]] = {s["role"]: s for s in (base_stages or []) if "role" in s} + passthrough: list[dict[str, Any]] = [s for s in (base_stages or []) if "role" not in s] + for overlay_stage in overlay_stages or []: + role = overlay_stage.get("role") + if role is not None and role in by_role: + by_role[role] = _deep_merge_stage(by_role[role], overlay_stage) + elif role is not None: + by_role[role] = overlay_stage + else: + passthrough.append(overlay_stage) + return [*by_role.values(), *passthrough] + + +def _merge_pd_disaggregation( + base: dict[str, Any] | None, + overlay: dict[str, Any] | None, +) -> dict[str, Any] | None: + """Deep-merge two ``pd_disaggregation:`` blocks.""" + if not base and not overlay: + return None + base = base or {} + overlay = overlay or {} + merged = {**base, **{k: v for k, v in overlay.items() if k not in ("stages", "stage_overrides")}} + merged["stages"] = _merge_pd_stage_role_lists(base.get("stages"), overlay.get("stages")) + merged["stage_overrides"] = _merge_stage_lists(base.get("stage_overrides"), overlay.get("stage_overrides")) + return merged + + def resolve_deploy_yaml(path: str | Path) -> dict[str, Any]: """Load a deploy YAML with optional ``base_config`` inheritance.""" raw_dict = to_dict(load_yaml_config(path)) @@ -554,16 +590,23 @@ def resolve_deploy_yaml(path: str | Path) -> dict[str, Any]: base_path = Path(path).parent / base_path base_dict = resolve_deploy_yaml(base_path) - # Merge top-level scalars: overlay wins. ``stages:`` and ``platforms:`` - # are deep-merged below so an overlay can layer on top of the base. + # Merge top-level scalars: overlay wins. ``stages:``, ``platforms:``, and + # ``pd_disaggregation:`` are deep-merged below so an overlay can layer on + # top of the base without restating every nested field. merged = { **base_dict, - **{k: v for k, v in raw_dict.items() if k not in ("stages", "platforms")}, + **{k: v for k, v in raw_dict.items() if k not in ("stages", "platforms", "pd_disaggregation", "pd_separation")}, } merged["stages"] = _merge_stage_lists(base_dict.get("stages"), raw_dict.get("stages")) merged_platforms = _merge_platforms(base_dict.get("platforms"), raw_dict.get("platforms")) if merged_platforms is not None: merged["platforms"] = merged_platforms + merged_pd = _merge_pd_disaggregation( + base_dict.get("pd_disaggregation", base_dict.get("pd_separation")), + raw_dict.get("pd_disaggregation", raw_dict.get("pd_separation")), + ) + if merged_pd is not None: + merged["pd_disaggregation"] = merged_pd return merged @@ -581,6 +624,7 @@ def load_deploy_config(path: str | Path) -> DeployConfig: "stages": stages, "platforms": raw_dict.get("platforms", None), "pipeline": raw_dict.get("pipeline", None), + "pd_disaggregation": raw_dict.get("pd_disaggregation", raw_dict.get("pd_separation", None)), } # Pipeline-wide engine settings: only set if explicitly present in YAML # so the DeployConfig dataclass defaults take effect otherwise. @@ -645,6 +689,240 @@ def _apply_platform_overrides( return deploy +_STAGE_CONNECTOR_PATTERN = re.compile(r"^(from_stage_|to_stage_)(\d+)$") + + +def _remap_split_stage_id(stage_id: int, target_stage_id: int) -> int: + """Shift stage IDs after inserting a decode stage after ``target_stage_id``.""" + return stage_id + 1 if stage_id > target_stage_id else stage_id + + +def _remap_split_source_id(source_id: int, target_stage_id: int) -> int: + """Remap a dependency/source edge across a PD thinker split. + + Downstream consumers that previously depended on the original thinker stage + should now consume the decode stage (``target_stage_id + 1``). + """ + if source_id == target_stage_id: + return target_stage_id + 1 + if source_id > target_stage_id: + return source_id + 1 + return source_id + + +def _remap_stage_connectors( + connectors: dict[str, str] | None, + target_stage_id: int, +) -> dict[str, str] | None: + """Rewrite ``from_stage_N`` / ``to_stage_N`` connector keys after a PD split.""" + if not connectors: + return connectors + + remapped: dict[str, str] = {} + for key, value in connectors.items(): + match = _STAGE_CONNECTOR_PATTERN.match(key) + if match is None: + remapped[key] = value + continue + prefix, raw_stage_id = match.groups() + ref_stage_id = _remap_split_source_id(int(raw_stage_id), target_stage_id) + remapped[f"{prefix}{ref_stage_id}"] = value + return remapped + + +def _pd_role_override(pd_cfg: dict[str, Any], role: str) -> dict[str, Any]: + """Return deploy overrides for a PD role from ``pd_disaggregation`` config.""" + role_cfg = pd_cfg.get(f"{role}_stage") + if isinstance(role_cfg, dict): + return dict(role_cfg) + + for entry in pd_cfg.get("stages", []) or []: + if entry.get("role") == role: + return dict(entry) + return {} + + +def _pd_stage_override(pd_cfg: dict[str, Any], stage_id: int) -> dict[str, Any]: + """Return PD-only deploy overrides for one original stage_id.""" + stage_cfg = pd_cfg.get(f"stage_{stage_id}") + if isinstance(stage_cfg, dict): + return dict(stage_cfg) + + stage_overrides = pd_cfg.get("stage_overrides") + if isinstance(stage_overrides, dict): + for key in (stage_id, str(stage_id)): + entry = stage_overrides.get(key) + if isinstance(entry, dict): + return dict(entry) + + for entry in stage_overrides or []: + if entry.get("stage_id") == stage_id: + return dict(entry) + return {} + + +def _make_pd_stage_deploy( + base: StageDeployConfig, + *, + stage_id: int, + stage_override: dict[str, Any], + target_stage_id: int, +) -> StageDeployConfig: + """Build one PD-mode stage deploy config from the original base stage.""" + override = dict(stage_override) + override.pop("role", None) + override.pop("stage_id", None) + override_engine_extras = dict(override.pop("engine_extras", {}) or {}) + + return dataclasses.replace( + base, + stage_id=stage_id, + max_num_seqs=override.pop("max_num_seqs", base.max_num_seqs), + gpu_memory_utilization=override.pop("gpu_memory_utilization", base.gpu_memory_utilization), + tensor_parallel_size=override.pop("tensor_parallel_size", base.tensor_parallel_size), + enforce_eager=override.pop("enforce_eager", base.enforce_eager), + max_num_batched_tokens=override.pop("max_num_batched_tokens", base.max_num_batched_tokens), + max_model_len=override.pop("max_model_len", base.max_model_len), + async_scheduling=override.pop("async_scheduling", base.async_scheduling), + devices=override.pop("devices", base.devices), + output_connectors=_remap_stage_connectors( + override.pop("output_connectors", base.output_connectors), + target_stage_id, + ), + input_connectors=_remap_stage_connectors( + override.pop("input_connectors", base.input_connectors), + target_stage_id, + ), + default_sampling_params=override.pop("default_sampling_params", base.default_sampling_params), + engine_extras={ + **base.engine_extras, + **override_engine_extras, + **override, + }, + ) + + +def _apply_pd_disaggregation( + pipeline: PipelineConfig, + deploy: DeployConfig, + cli_overrides: dict[str, Any] | None = None, +) -> tuple[PipelineConfig, DeployConfig]: + """Optionally split one pipeline stage into prefill + decode stages.""" + pd_cfg = dict(deploy.pd_disaggregation or {}) + if cli_overrides is not None and cli_overrides.get("enable_pd_disaggregation") is not None: + pd_cfg["enabled"] = bool(cli_overrides["enable_pd_disaggregation"]) + if not pd_cfg.get("enabled", False): + return pipeline, deploy + + target_stage_id = int(pd_cfg.get("target_stage_id", 0)) + target_stage = pipeline.get_stage(target_stage_id) + if target_stage is None: + raise ValueError(f"PD {target_stage_id} not found in pipeline {pipeline.model_type!r}") + if target_stage.execution_type != StageExecutionType.LLM_AR: + raise ValueError( + f"PD disaggregation only supports LLM_AR stages; stage {target_stage_id} " + f"uses {target_stage.execution_type.value!r}" + ) + + deploy_by_id = {stage.stage_id: stage for stage in deploy.stages} + target_deploy = deploy_by_id.get(target_stage_id) + if target_deploy is None: + raise ValueError(f"PD disaggregation target stage {target_stage_id} missing deploy settings") + + prefill_override = _pd_role_override(pd_cfg, "prefill") + decode_override = _pd_role_override(pd_cfg, "decode") + if not prefill_override or not decode_override: + raise ValueError("PD disaggregation requires both prefill and decode role configs") + + prefill_stage = dataclasses.replace( + target_stage, + final_output=False, + final_output_type=None, + owns_tokenizer=False, + custom_process_input_func=None, + custom_process_next_stage_input_func=None, + async_chunk_process_next_stage_input_func=None, + sync_process_input_func=None, + extras={**target_stage.extras, "is_prefill_only": True}, + ) + decode_stage = dataclasses.replace( + target_stage, + stage_id=target_stage_id + 1, + input_sources=(target_stage_id,), + owns_tokenizer=target_stage.owns_tokenizer, + custom_process_next_stage_input_func=None, + async_chunk_process_next_stage_input_func=None, + sync_process_input_func=None, + extras={**target_stage.extras, "is_decode_only": True}, + ) + + new_pipeline_stages: list[StagePipelineConfig] = [] + for stage in pipeline.stages: + if stage.stage_id < target_stage_id: + new_pipeline_stages.append(stage) + continue + if stage.stage_id == target_stage_id: + new_pipeline_stages.extend((prefill_stage, decode_stage)) + continue + + new_pipeline_stages.append( + dataclasses.replace( + stage, + stage_id=_remap_split_stage_id(stage.stage_id, target_stage_id), + input_sources=tuple( + _remap_split_source_id(source_id, target_stage_id) for source_id in stage.input_sources + ), + ) + ) + + new_deploy_stages: list[StageDeployConfig] = [] + for stage in deploy.stages: + if stage.stage_id < target_stage_id: + new_deploy_stages.append(stage) + continue + if stage.stage_id == target_stage_id: + new_deploy_stages.append( + _make_pd_stage_deploy( + target_deploy, + stage_id=target_stage_id, + stage_override=prefill_override, + target_stage_id=target_stage_id, + ) + ) + new_deploy_stages.append( + _make_pd_stage_deploy( + target_deploy, + stage_id=target_stage_id + 1, + stage_override=decode_override, + target_stage_id=target_stage_id, + ) + ) + continue + + new_deploy_stages.append( + _make_pd_stage_deploy( + stage, + stage_id=_remap_split_stage_id(stage.stage_id, target_stage_id), + stage_override=_pd_stage_override(pd_cfg, stage.stage_id), + target_stage_id=target_stage_id, + ) + ) + + effective_async_chunk = bool(pd_cfg.get("async_chunk", False)) + if cli_overrides is not None and cli_overrides.get("async_chunk") is not None: + effective_async_chunk = bool(cli_overrides["async_chunk"]) + + return ( + dataclasses.replace(pipeline, stages=tuple(new_pipeline_stages)), + dataclasses.replace( + deploy, + async_chunk=effective_async_chunk, + stages=new_deploy_stages, + pd_disaggregation=None, + ), + ) + + _EXECUTION_TYPE_TO_STAGE_WORKER: dict[StageExecutionType, tuple[StageType, str | None]] = { StageExecutionType.LLM_AR: (StageType.LLM, "ar"), StageExecutionType.LLM_GENERATION: (StageType.LLM, "generation"), @@ -769,6 +1047,7 @@ def merge_pipeline_deploy( cli_overrides = {} deploy = _apply_platform_overrides(deploy) + pipeline, deploy = _apply_pd_disaggregation(pipeline, deploy, cli_overrides) deploy_by_id = {s.stage_id: s for s in deploy.stages} # A pipeline supports async_chunk if any stage has either an explicit diff --git a/vllm_omni/deploy/qwen3_omni_moe.yaml b/vllm_omni/deploy/qwen3_omni_moe.yaml index 39baed6bd7b..38de3c0ce3d 100644 --- a/vllm_omni/deploy/qwen3_omni_moe.yaml +++ b/vllm_omni/deploy/qwen3_omni_moe.yaml @@ -60,6 +60,45 @@ stages: seed: 42 repetition_penalty: 1.1 +pd_disaggregation: + enabled: false + target_stage_id: 0 + async_chunk: false + # PD mode uses the common 3-GPU layout: + # GPU 0 -> thinker prefill + # GPU 1 -> thinker decode + # GPU 2 -> talker + code2wav + stages: + - role: prefill + max_num_seqs: 16 + devices: "0" + tensor_parallel_size: 1 + engine_extras: + kv_transfer_config: + kv_connector: "MooncakeConnector" + kv_role: "kv_producer" + kv_rank: 0 + kv_parallel_size: 2 + kv_connector_extra_config: + mooncake_bootstrap_port: 25201 + - role: decode + max_num_seqs: 64 + devices: "1" + tensor_parallel_size: 1 + engine_extras: + kv_transfer_config: + kv_connector: "MooncakeConnector" + kv_role: "kv_consumer" + kv_rank: 1 + kv_parallel_size: 2 + kv_connector_extra_config: + mooncake_bootstrap_port: 25202 + stage_overrides: + - stage_id: 1 + devices: "2" + - stage_id: 2 + devices: "2" + platforms: npu: stages: diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index 3f16c329e27..1dbc40ab81f 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -427,6 +427,7 @@ class OrchestratorArgs: # === Mode Switches (orchestrator reads, DeployConfig redistributes) === async_chunk: bool | None = None + enable_pd_disaggregation: bool | None = None # === Observability === log_stats: bool = False diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 66096de7c0c..d8f43a5642d 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -1233,13 +1233,6 @@ def _detect_pd_config(self) -> dict[str, Any] | None: bootstrap_addr = f"http://{kv_ip}:{port}" except Exception as exc: logger.warning("[AsyncOmniEngine] Could not extract PD bootstrap address: %s", exc) - - logger.info( - "[AsyncOmniEngine] PD disaggregation detected: prefill=stage-%d, decode=stage-%d, bootstrap=%s", - prefill_idx, - decode_idx, - bootstrap_addr, - ) prefill_engine_id: str | None = None try: prefill_client = self.stage_clients[prefill_idx] @@ -1248,6 +1241,13 @@ def _detect_pd_config(self) -> dict[str, Any] | None: except Exception as exc: logger.warning("[AsyncOmniEngine] Could not extract prefill engine_id: %s", exc) + logger.info( + "[AsyncOmniEngine] PD disaggregation detected: prefill=stage-%d, decode=stage-%d, bootstrap=%s", + prefill_idx, + decode_idx, + bootstrap_addr, + ) + return { "pd_pair": (prefill_idx, decode_idx), "bootstrap_addr": bootstrap_addr, diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index b53bdae84d4..82493aad71c 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -12,6 +12,7 @@ import copy import time as _time from dataclasses import dataclass, field +from functools import lru_cache from typing import Any import janus @@ -38,6 +39,74 @@ logger = init_logger(__name__) +@lru_cache(maxsize=8) +def _load_tokenizer_eos_token_id( + tokenizer_name_or_path: str, + trust_remote_code: bool, + revision: str | None, +) -> int | None: + try: + from transformers import AutoTokenizer + except Exception: + return None + + try: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name_or_path, + trust_remote_code=trust_remote_code, + revision=revision, + ) + except Exception: + logger.exception( + "[Orchestrator] failed to load tokenizer for eos lookup: %s", + tokenizer_name_or_path, + ) + return None + + eos_token_id = getattr(tokenizer, "eos_token_id", None) + if isinstance(eos_token_id, int): + return eos_token_id + if isinstance(eos_token_id, (list, tuple)): + for token_id in eos_token_id: + if isinstance(token_id, int): + return token_id + return None + + +def _resolve_model_eos_token_id(model_config: ModelConfig | None) -> int | None: + if model_config is None: + return None + hf_text_cfg = getattr(model_config, "hf_text_config", None) + eos_token_id = getattr(hf_text_cfg, "eos_token_id", None) + if isinstance(eos_token_id, int): + return eos_token_id + if isinstance(eos_token_id, (list, tuple)): + for token_id in eos_token_id: + if isinstance(token_id, int): + return token_id + + hf_cfg = getattr(model_config, "hf_config", None) + eos_token_id = getattr(hf_cfg, "eos_token_id", None) + if isinstance(eos_token_id, int): + return eos_token_id + if isinstance(eos_token_id, (list, tuple)): + for token_id in eos_token_id: + if isinstance(token_id, int): + return token_id + + tokenizer_name_or_path = getattr(model_config, "tokenizer", None) or getattr(model_config, "model", None) + if isinstance(tokenizer_name_or_path, str) and tokenizer_name_or_path: + eos_token_id = _load_tokenizer_eos_token_id( + tokenizer_name_or_path=tokenizer_name_or_path, + trust_remote_code=bool(getattr(model_config, "trust_remote_code", False)), + revision=getattr(model_config, "tokenizer_revision", None), + ) + if isinstance(eos_token_id, int): + return eos_token_id + + return None + + def build_engine_core_request_from_tokens( request_id: str, prompt: dict[str, Any], @@ -444,17 +513,6 @@ async def _route_output( kv_params = getattr(output, "kv_transfer_params", None) if kv_params is not None: self._pd_kv_params[req_id] = kv_params if isinstance(kv_params, dict) else dict(kv_params) - logger.debug( - "[Orchestrator][PD] stored kv_transfer_params for req=%s (keys=%s)", - req_id, - list(self._pd_kv_params[req_id].keys()), - ) - else: - logger.warning( - "[Orchestrator][PD] prefill stage output for req=%s has no kv_transfer_params; " - "KV transfer may fail. Ensure apply_mooncake_connector_patch() was called.", - req_id, - ) if ( (finished or (req_state.streaming.enabled and req_state.streaming.segment_finished)) @@ -491,6 +549,7 @@ async def _route_output( if finished and stage_id == req_state.final_stage_id: # PD: clean up any lingering KV params for this request self._pd_kv_params.pop(req_id, None) + self._clear_pd_prefill_multimodal_output(req_id) self._cfg_tracker.cleanup_parent(req_id) self.request_states.pop(req_id, None) @@ -548,14 +607,15 @@ def _build_pd_decode_params(self, req_id: str, sp: Any) -> Any: if sp.extra_args is None: sp.extra_args = {} - # Get KV params captured from the prefill output (must include remote_request_id). - kv_prefill_params = self._pd_kv_params.pop(req_id, None) - if not kv_prefill_params or "remote_request_id" not in kv_prefill_params: - raise RuntimeError( - f"[Orchestrator][PD] Missing prefill kv_transfer_params.remote_request_id for req={req_id}" - ) + # Newer Mooncake routing keys requests by transfer_id and uses + # remote_engine_id + remote_bootstrap_addr to discover the producer. + # Prefill-side kv_transfer_params are therefore optional; when present + # we only treat them as an overlay/fallback. + kv_prefill_params = self._pd_kv_params.pop(req_id, None) or {} decode_kv_params: dict[str, Any] = { + "do_remote_prefill": True, + "do_remote_decode": False, "transfer_id": f"xfer-{req_id}", } @@ -565,9 +625,18 @@ def _build_pd_decode_params(self, req_id: str, sp: Any) -> Any: if self._pd_prefill_engine_id: decode_kv_params["remote_engine_id"] = self._pd_prefill_engine_id - # Overlay params from prefill side (includes remote_request_id set by monkey patch). + # Overlay any prefill-side params (legacy patches may still thread + # remote_request_id or other connector-specific fields through outputs). decode_kv_params.update(kv_prefill_params) + missing = [ + key for key in ("transfer_id", "remote_bootstrap_addr", "remote_engine_id") if not decode_kv_params.get(key) + ] + if missing: + raise RuntimeError( + f"[Orchestrator][PD] Missing decode kv_transfer_params fields for req={req_id}: {', '.join(missing)}" + ) + # Ensure these flags are set correctly after any overlay. decode_kv_params["do_remote_prefill"] = True decode_kv_params["do_remote_decode"] = False @@ -575,14 +644,15 @@ def _build_pd_decode_params(self, req_id: str, sp: Any) -> Any: decode_kv_params["transfer_id"] = f"xfer-{req_id}" sp.extra_args["kv_transfer_params"] = decode_kv_params - - logger.debug( - "[Orchestrator][PD] decode kv_transfer_params for req=%s: %s", - req_id, - decode_kv_params, - ) return sp + def _clear_pd_prefill_multimodal_output(self, req_id: str) -> None: + if self._pd_pair is None: + return + clear_fn = getattr(self.stage_clients[self._pd_pair[0]], "clear_pd_prefill_multimodal_output", None) + if callable(clear_fn): + clear_fn(req_id) + def _build_stage_metrics( self, stage_id: int, @@ -718,6 +788,7 @@ async def _forward_to_next_stage( } ) self._pd_kv_params.pop(req_id, None) + self._clear_pd_prefill_multimodal_output(req_id) self._cfg_tracker.cleanup_parent(req_id) self.request_states.pop(req_id, None) return @@ -748,8 +819,12 @@ async def _forward_to_next_stage( # PD disaggregation: prefill → decode routing uses original prompt + KV transfer params if self._pd_pair is not None and (stage_id, next_stage_id) == self._pd_pair: - # Save prefill stage outputs so thinker2talker can merge embeddings later - self.stage_clients[stage_id].set_engine_outputs([output]) + prefill_mm = None + try: + prefill_mm = output.outputs[0].multimodal_output + except Exception: + prefill_mm = None + self.stage_clients[stage_id].set_pd_prefill_multimodal_output(req_id, prefill_mm) params = self._build_pd_decode_params(req_id, params) @@ -879,6 +954,12 @@ async def _process_stage_outputs(self, stage_id: int, raw_outputs: EngineCoreOut # and OutputProcessor would leak. self.output_processors[stage_id].abort_requests(processed.reqs_to_abort, internal=True) for req_id in processed.reqs_to_abort: + # Mirror final-stage / explicit abort cleanup for PD-specific + # transient state so aborted requests don't leak across tests. + self._pd_kv_params.pop(req_id, None) + self._pd_decode_prev_output_len.pop(req_id, None) + self._clear_pd_prefill_multimodal_output(req_id) + self._cfg_tracker.cleanup_parent(req_id) self.request_states.pop(req_id, None) if raw_outputs.scheduler_stats is not None: @@ -1095,6 +1176,7 @@ async def _handle_abort(self, msg: dict[str, Any]) -> None: # EngineCoreOutput, so we must purge them explicitly. self.output_processors[stage_id].abort_requests(all_ids_to_abort, internal=True) for req_id in all_ids_to_abort: + self._clear_pd_prefill_multimodal_output(req_id) self.request_states.pop(req_id, None) logger.info("[Orchestrator] Aborted request(s) %s", request_ids) @@ -1210,6 +1292,7 @@ async def _drain_pending_requests_on_fatal(self) -> None: "stage_id": self._fatal_error_stage_id, } ) + self._clear_pd_prefill_multimodal_output(req_id) self.request_states.pop(req_id, None) def _shutdown_stages(self) -> None: diff --git a/vllm_omni/engine/stage_engine_core_client.py b/vllm_omni/engine/stage_engine_core_client.py index 37b0e538a13..1655c49b28b 100644 --- a/vllm_omni/engine/stage_engine_core_client.py +++ b/vllm_omni/engine/stage_engine_core_client.py @@ -121,8 +121,11 @@ def __init__( self.default_sampling_params = metadata.default_sampling_params self.custom_process_input_func = metadata.custom_process_input_func self.model_stage = metadata.model_stage + self.is_prefill_only = metadata.is_prefill_only + self.is_decode_only = metadata.is_decode_only self.engine_outputs: Any = None + self.pd_prefill_multimodal_by_req: dict[str, dict[str, Any]] = {} self._proc = proc self.client_addresses = dict(client_addresses or {}) self._omni_kv_config = getattr(getattr(vllm_config, "model_config", None), "omni_kv_config", None) @@ -356,6 +359,18 @@ def set_engine_outputs(self, engine_outputs: EngineCoreOutput) -> None: """Set engine outputs (called by orchestrator).""" self.engine_outputs = engine_outputs + def set_pd_prefill_multimodal_output(self, req_id: str, multimodal_output: dict[str, Any] | None) -> None: + if multimodal_output is None: + self.pd_prefill_multimodal_by_req.pop(req_id, None) + else: + self.pd_prefill_multimodal_by_req[req_id] = multimodal_output + + def get_pd_prefill_multimodal_output(self, req_id: str) -> dict[str, Any] | None: + return self.pd_prefill_multimodal_by_req.get(req_id) + + def clear_pd_prefill_multimodal_output(self, req_id: str) -> None: + self.pd_prefill_multimodal_by_req.pop(req_id, None) + def process_engine_inputs( self, stage_list: list[Any], diff --git a/vllm_omni/engine/stage_init_utils.py b/vllm_omni/engine/stage_init_utils.py index 277edb5eb4d..47f1115a6e1 100644 --- a/vllm_omni/engine/stage_init_utils.py +++ b/vllm_omni/engine/stage_init_utils.py @@ -257,6 +257,8 @@ class StageMetadata: custom_process_input_func: Callable | None model_stage: str | None runtime_cfg: Any + is_prefill_only: bool = False + is_decode_only: bool = False prompt_expand_func: Callable | None = None cfg_kv_collect_func: Callable | None = None @@ -335,6 +337,8 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata: custom_process_input_func=custom_process_input_func, model_stage=None, runtime_cfg=runtime_cfg, + is_prefill_only=bool(getattr(stage_config, "is_prefill_only", False)), + is_decode_only=bool(getattr(stage_config, "is_decode_only", False)), cfg_kv_collect_func=cfg_kv_collect_func, ) @@ -356,6 +360,8 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata: custom_process_input_func=custom_process_input_func, model_stage=model_stage, runtime_cfg=runtime_cfg, + is_prefill_only=bool(getattr(stage_config, "is_prefill_only", False)), + is_decode_only=bool(getattr(stage_config, "is_decode_only", False)), prompt_expand_func=prompt_expand_func, ) diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index 8e74901ddd8..6d352db7212 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -181,6 +181,13 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu default=None, help="Override the deploy YAML's ``async_chunk:`` bool. Unset leaves the YAML value in force.", ) + omni_config_group.add_argument( + "--enable-pd-disaggregation", + action=argparse.BooleanOptionalAction, + default=None, + help="Override the deploy YAML's ``pd_disaggregation.enabled:`` bool. " + "Use with --stage-overrides to tune the generated prefill/decode stages from CLI.", + ) omni_config_group.add_argument( "--stage-id", type=int, diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 28b969ff7cd..e34cabf8e30 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -841,6 +841,10 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch # compatible with old shape [1,S,D] rem_tail = trailing_text_hidden.squeeze(0) if rem_tail.shape[0] > 0: + # Prefill bootstrap already consumed the first assistant text token. + # Keep decode aligned to the remaining text boundary. + if rem_tail.ndim == 2 and rem_tail.shape[0] > 1: + rem_tail = rem_tail[1:, :] update_dict.setdefault("hidden_states", {})["trailing_text"] = rem_tail.detach() # Also persist projected tts_pad for decode fallback if needed if isinstance(tts_pad_thinker, torch.Tensor): diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index f7da44067e7..f76cbc606aa 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -117,7 +117,7 @@ def _get_prefill_stage(stage_list: list[Any], source_stage_id: int) -> Any | Non if not getattr(source_stage, "is_decode_only", False): return None prev_stage = stage_list[source_stage_id - 1] - if getattr(prev_stage, "is_prefill_only", False) and prev_stage.engine_outputs is not None: + if getattr(prev_stage, "is_prefill_only", False): return prev_stage return None @@ -167,14 +167,11 @@ def _merge_pd_embeddings( return merged_emb, merged_hid -def _get_prefill_multimodal_output(prefill_stage: Any, output_index: int) -> dict[str, Any] | None: - """Return multimodal_output dict from the PD prefill stage for a given batch index.""" - try: - prefill_eos = prefill_stage.engine_outputs - prefill_eo = prefill_eos[min(output_index, len(prefill_eos) - 1)] - return prefill_eo.outputs[0].multimodal_output - except Exception: - return None +def _get_prefill_multimodal_output(prefill_stage: Any, req_id: str) -> dict[str, Any] | None: + get_fn = getattr(prefill_stage, "get_pd_prefill_multimodal_output", None) + if callable(get_fn): + return get_fn(req_id) + return None def _resolve_tts_token_embedding( @@ -479,7 +476,7 @@ def thinker2talker( prefill_mm: dict[str, Any] | None = None if prefill_stage is not None: - prefill_mm = _get_prefill_multimodal_output(prefill_stage, i) + prefill_mm = _get_prefill_multimodal_output(prefill_stage, req_id) if prefill_mm is not None: expected_total = len(prompt_token_ids) + len(output_ids)