From 9a7114678db2cd4c26e2a6425aa05a16b5e5270f Mon Sep 17 00:00:00 2001 From: "ramos.ma" Date: Wed, 11 Mar 2026 11:19:54 +0800 Subject: [PATCH 01/14] Support FunAudioChat S2S in vLLM-Omni Signed-off-by: ramos.ma Signed-off-by: mayufeng --- .gitignore | 1 + docs/models/supported_models.md | 1 + pyproject.toml | 1 + .../entrypoints/test_funaudiochat_contrib.py | 38 ++ .../models/test_funaudiochat_code2wav.py | 96 +++ .../models/test_funaudiochat_native.py | 495 ++++++++++++++ .../test_funaudiochat.py | 88 +++ tests/test_outputs.py | 77 ++- tests/worker/test_omni_gpu_model_runner.py | 222 +++++++ vllm_omni/engine/arg_utils.py | 9 +- vllm_omni/engine/input_processor.py | 7 +- vllm_omni/entrypoints/omni.py | 13 +- vllm_omni/entrypoints/omni_stage.py | 5 + vllm_omni/model_executor/models/__init__.py | 3 + .../models/cosyvoice3/cosyvoice3_code2wav.py | 1 - .../models/cosyvoice3/hf_config/config.json | 3 + .../models/cosyvoice3/hf_config_utils.py | 17 + .../models/funaudiochat/__init__.py | 4 + .../models/funaudiochat/common.py | 277 ++++++++ .../models/funaudiochat/funaudiochat.py | 622 ++++++++++++++++++ .../funaudiochat/funaudiochat_code2wav.py | 453 +++++++++++++ vllm_omni/model_executor/models/registry.py | 10 + .../stage_configs/funaudiochat_s2s.yaml | 66 ++ .../stage_input_processors/funaudiochat.py | 78 +++ vllm_omni/outputs.py | 39 +- .../transformers_utils/configs/__init__.py | 5 + .../configs/funaudiochat.py | 119 ++++ vllm_omni/worker/gpu_ar_model_runner.py | 103 ++- .../worker/gpu_generation_model_runner.py | 14 +- vllm_omni/worker/gpu_model_runner.py | 210 +++++- 30 files changed, 2981 insertions(+), 96 deletions(-) create mode 100644 tests/entrypoints/test_funaudiochat_contrib.py create mode 100644 tests/model_executor/models/test_funaudiochat_code2wav.py create mode 100644 tests/model_executor/models/test_funaudiochat_native.py create mode 100644 tests/model_executor/stage_input_processors/test_funaudiochat.py create mode 100644 
vllm_omni/model_executor/models/cosyvoice3/hf_config/config.json create mode 100644 vllm_omni/model_executor/models/cosyvoice3/hf_config_utils.py create mode 100644 vllm_omni/model_executor/models/funaudiochat/__init__.py create mode 100644 vllm_omni/model_executor/models/funaudiochat/common.py create mode 100644 vllm_omni/model_executor/models/funaudiochat/funaudiochat.py create mode 100644 vllm_omni/model_executor/models/funaudiochat/funaudiochat_code2wav.py create mode 100644 vllm_omni/model_executor/stage_configs/funaudiochat_s2s.yaml create mode 100644 vllm_omni/model_executor/stage_input_processors/funaudiochat.py create mode 100644 vllm_omni/transformers_utils/configs/funaudiochat.py diff --git a/.gitignore b/.gitignore index 28d56e0f6f0..1706cfd03aa 100644 --- a/.gitignore +++ b/.gitignore @@ -223,6 +223,7 @@ datasets/ *.csv *.json !apps/ComfyUI-vLLM-Omni/example_workflows/*.json +!vllm_omni/model_executor/models/cosyvoice3/hf_config/config.json *.jsonl *.parquet diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 4c709425e66..3fe0021f936 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -38,6 +38,7 @@ th { |`LongCatImageEditPipeline` | LongCat-Image-Edit | `meituan-longcat/LongCat-Image-Edit` | |`StableDiffusion3Pipeline` | Stable-Diffusion-3 | `stabilityai/stable-diffusion-3.5-medium` | |`CosyVoice3Model` | CosyVoice3 | `FunAudioLLM/Fun-CosyVoice3-0.5B-2512` | +|`FunAudioChatForConditionalGeneration` | Fun-Audio-Chat-8B | `FunAudioLLM/Fun-Audio-Chat-8B` | |`MammothModa2ForConditionalGeneration` | MammothModa2-Preview | `bytedance-research/MammothModa2-Preview` | |`Flux2KleinPipeline` | FLUX.2-klein | `black-forest-labs/FLUX.2-klein-4B`, `black-forest-labs/FLUX.2-klein-9B` | |`FluxPipeline` | FLUX.1-dev | `black-forest-labs/FLUX.1-dev` | diff --git a/pyproject.toml b/pyproject.toml index 5bcfab7d2d6..11ef133c2ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ include = 
["vllm_omni*"] [tool.setuptools.package-data] "vllm_omni" = ["_version.py", "py.typed"] +"vllm_omni.model_executor.models.cosyvoice3" = ["hf_config/*.json"] "vllm_omni.model_executor.stage_configs" = ["*.yaml"] [tool.setuptools_scm] diff --git a/tests/entrypoints/test_funaudiochat_contrib.py b/tests/entrypoints/test_funaudiochat_contrib.py new file mode 100644 index 00000000000..7ca70f0eebf --- /dev/null +++ b/tests/entrypoints/test_funaudiochat_contrib.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from vllm_omni.engine.arg_utils import _resolve_bundled_hf_config_path +from vllm_omni.entrypoints.omni import OmniBase + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def test_resolve_bundled_hf_config_path_uses_cosyvoice3_bundle_by_default(): + resolved = _resolve_bundled_hf_config_path("FunAudioChatCosyVoice3Code2Wav", None) + + assert resolved is not None + assert resolved.endswith("vllm_omni/model_executor/models/cosyvoice3/hf_config") + assert (Path(resolved) / "config.json").is_file() + + +def test_resolve_bundled_hf_config_path_preserves_explicit_override(): + resolved = _resolve_bundled_hf_config_path("FunAudioChatCosyVoice3Code2Wav", "/tmp/custom-hf-config") + + assert resolved == "/tmp/custom-hf-config" + + +def test_get_stage_model_prefers_stage_override(): + stage = SimpleNamespace(engine_args=SimpleNamespace(model="stage-specific-model")) + + assert OmniBase._get_stage_model(stage, "fallback-model") == "stage-specific-model" + + +def test_get_stage_model_falls_back_when_stage_override_missing(): + stage = SimpleNamespace(engine_args=SimpleNamespace()) + + assert OmniBase._get_stage_model(stage, "fallback-model") == "fallback-model" diff --git a/tests/model_executor/models/test_funaudiochat_code2wav.py b/tests/model_executor/models/test_funaudiochat_code2wav.py new file mode 
100644 index 00000000000..b4385d435eb --- /dev/null +++ b/tests/model_executor/models/test_funaudiochat_code2wav.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import torch + +from vllm_omni.model_executor.models.funaudiochat.funaudiochat_code2wav import ( + FunAudioChatCosyVoice3Code2Wav, +) + + +def test_split_tokens_like_official_keeps_short_inputs_as_single_segment(): + token = torch.arange(100, dtype=torch.long) + + segments = FunAudioChatCosyVoice3Code2Wav._split_tokens_like_official(token) + + assert len(segments) == 1 + assert torch.equal(segments[0], token) + + +def test_split_tokens_like_official_rebalances_tiny_tail_segment(): + token = torch.arange(760, dtype=torch.long) + + segments = FunAudioChatCosyVoice3Code2Wav._split_tokens_like_official(token) + + assert [segment.numel() for segment in segments] == [380, 380] + assert torch.equal(torch.cat(segments, dim=0), token) + + +def _build_code2wav_stub() -> FunAudioChatCosyVoice3Code2Wav: + model = object.__new__(FunAudioChatCosyVoice3Code2Wav) + model.vllm_config = SimpleNamespace(device_config=SimpleNamespace(device=torch.device("cpu"))) + model._max_codec_token_id = 6560 + model._dummy_profile_token_len = 32 + model._logged_dummy_profile_cap = False + return model + + +def test_build_decode_tokens_keeps_real_input_ids_without_sampling_metadata(): + model = _build_code2wav_stub() + input_ids = torch.tensor([12, 34, 56], dtype=torch.long) + + token_batches, is_dummy_profile = model._build_decode_tokens(input_ids, sampling_metadata=None) + + assert len(token_batches) == 1 + assert token_batches[0].tolist() == [[12, 34, 56]] + assert is_dummy_profile is False + + +def test_build_decode_tokens_uses_prompt_token_ids_when_input_ids_are_empty(): + model = _build_code2wav_stub() + sampling_metadata = SimpleNamespace(prompt_token_ids=[1, 2, 3, 4]) + + token_batches, is_dummy_profile = model._build_decode_tokens( + torch.empty((0,), dtype=torch.long), + 
sampling_metadata, + ) + + assert len(token_batches) == 1 + assert token_batches[0].tolist() == [[1, 2, 3, 4]] + assert is_dummy_profile is False + + +def test_build_decode_tokens_treats_all_zero_missing_metadata_as_dummy_profile(): + model = _build_code2wav_stub() + input_ids = torch.zeros((64,), dtype=torch.long) + + token_batches, is_dummy_profile = model._build_decode_tokens(input_ids, sampling_metadata=None) + + assert len(token_batches) == 1 + assert token_batches[0].shape == (1, 32) + assert is_dummy_profile is True + + +def test_build_decode_tokens_no_longer_rejects_long_sequences_before_segmentation(): + model = _build_code2wav_stub() + input_ids = torch.arange(10235, dtype=torch.long) % 6000 + + token_batches, is_dummy_profile = model._build_decode_tokens(input_ids, sampling_metadata=None) + + assert len(token_batches) == 1 + assert token_batches[0].shape == (1, 10235) + assert is_dummy_profile is False + + +def test_build_decode_tokens_preserves_batched_prompt_token_ids_per_request(): + model = _build_code2wav_stub() + sampling_metadata = SimpleNamespace(prompt_token_ids=[[1, 2, 3], [4, 5]]) + + token_batches, is_dummy_profile = model._build_decode_tokens( + torch.empty((0,), dtype=torch.long), + sampling_metadata, + ) + + assert [token.tolist() for token in token_batches] == [[[1, 2, 3]], [[4, 5]]] + assert is_dummy_profile is False diff --git a/tests/model_executor/models/test_funaudiochat_native.py b/tests/model_executor/models/test_funaudiochat_native.py new file mode 100644 index 00000000000..09786a4079c --- /dev/null +++ b/tests/model_executor/models/test_funaudiochat_native.py @@ -0,0 +1,495 @@ +from types import SimpleNamespace + +import torch +from transformers.generation.logits_process import ( + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopPLogitsWarper, +) +from transformers.modeling_outputs import BaseModelOutput + +import vllm_omni.model_executor.models.funaudiochat.funaudiochat as fac_mod +from 
vllm_omni.model_executor.models.funaudiochat.funaudiochat import ( + DEFAULT_SP_GEN_KWARGS, + FunAudioChatForConditionalGeneration, +) + + +def _make_model_stub( + *, + audio_bos_id: int = 42, + audio_eos_id: int = 99, + group_size: int = 5, + hidden_size: int = 4, +): + model = object.__new__(FunAudioChatForConditionalGeneration) + model.config = SimpleNamespace( + audio_config=SimpleNamespace(group_size=group_size, eos_token_id=audio_eos_id), + text_config=SimpleNamespace(audio_bos_index=audio_bos_id, audio_eos_index=audio_eos_id), + ) + model.sp_gen_kwargs = DEFAULT_SP_GEN_KWARGS.copy() + model._batch_preprocess_in_progress = False + model._batch_req_infos = [] + model._batch_sidecar_results = [] + model._postprocess_cursor = 0 + model._logged_stage0_backend = True + model.get_language_model = lambda: SimpleNamespace( + embed_input_ids=lambda input_ids: torch.zeros( + (input_ids.reshape(-1).numel(), hidden_size), + dtype=torch.float32, + device=input_ids.device, + ) + ) + model.audio_tower = lambda audio_ids: BaseModelOutput( + last_hidden_state=torch.full( + (audio_ids.shape[0], 1, hidden_size), + 2.0, + dtype=torch.float32, + device=audio_ids.device, + ) + ) + model._get_stage0_backend = lambda: "TEST" + return model + + +def test_default_sp_gen_kwargs_match_official_defaults(): + assert DEFAULT_SP_GEN_KWARGS == { + "text_greedy": True, + "only_crq_sampling": True, + "disable_speech": False, + "force_text_abos": True, + } + + +def test_pooler_output_buffer_only_snapshots_incremental_audio_groups(): + assert FunAudioChatForConditionalGeneration.pooler_output_buffer_keys == ("audio_token_ids",) + + +def test_build_crq_sampling_config_matches_official_sampling_defaults(): + model = _make_model_stub() + sampling_metadata = type( + "SamplingMetadataStub", + (), + { + "repetition_penalties": torch.tensor([1.2]), + "temperature": torch.tensor([0.8]), + "top_p": torch.tensor([0.9]), + "top_k": torch.tensor([0]), + }, + )() + + processors, do_sample = 
model._build_crq_sampling_config( + sampling_metadata=sampling_metadata, + req_index=0, + ) + + assert do_sample is True + assert any(isinstance(p, RepetitionPenaltyLogitsProcessor) for p in processors) + assert any(isinstance(p, TemperatureLogitsWarper) for p in processors) + assert any(isinstance(p, TopPLogitsWarper) for p in processors) + + +def test_build_crq_sampling_config_is_empty_for_greedy_without_penalties(): + model = _make_model_stub() + model.sp_gen_kwargs["text_greedy"] = False + sampling_metadata = type( + "SamplingMetadataStub", + (), + { + "repetition_penalties": torch.tensor([1.0]), + "temperature": None, + "top_p": None, + "top_k": None, + }, + )() + + processors, do_sample = model._build_crq_sampling_config( + sampling_metadata=sampling_metadata, + req_index=0, + ) + + assert do_sample is False + assert len(processors) == 0 + + +def test_build_crq_sampling_config_restores_official_audio_sampling_when_text_path_is_greedy(): + model = _make_model_stub() + model.sp_gen_kwargs["text_greedy"] = True + sampling_metadata = type( + "SamplingMetadataStub", + (), + { + "repetition_penalties": torch.tensor([1.2]), + "temperature": torch.tensor([0.0]), + "top_p": torch.tensor([1.0]), + "top_k": torch.tensor([-1]), + }, + )() + + processors, do_sample = model._build_crq_sampling_config( + sampling_metadata=sampling_metadata, + req_index=0, + ) + + assert do_sample is True + assert len(processors) == 3 + assert any(isinstance(p, RepetitionPenaltyLogitsProcessor) for p in processors) + assert any(isinstance(p, TemperatureLogitsWarper) for p in processors) + assert any(isinstance(p, TopPLogitsWarper) for p in processors) + + +def test_resolve_text_seq_len_prefill_accumulates_prompt_tokens(): + assert FunAudioChatForConditionalGeneration._resolve_text_seq_len(None, 5) == (5, 5) + assert FunAudioChatForConditionalGeneration._resolve_text_seq_len(5, 3) == (8, 8) + + +def test_resolve_text_seq_len_decode_advances_for_next_step(): + assert 
FunAudioChatForConditionalGeneration._resolve_text_seq_len(8, 1) == (8, 9) + assert FunAudioChatForConditionalGeneration._resolve_text_seq_len(None, 1) == (1, 2) + + +def test_resolve_next_speech_state_stays_text_only_until_audio_bos_is_sampled(): + final_token, next_speech_active, next_force_pending = ( + FunAudioChatForConditionalGeneration._resolve_next_speech_state( + sampled_token_id=7, + generate_speech=False, + finish_speech=False, + force_audio_bos_pending=False, + audio_bos_id=42, + audio_eos_id=99, + ) + ) + + assert final_token == 7 + assert next_speech_active is False + assert next_force_pending is False + + +def test_resolve_next_speech_state_arms_speech_after_audio_bos_is_sampled(): + final_token, next_speech_active, next_force_pending = ( + FunAudioChatForConditionalGeneration._resolve_next_speech_state( + sampled_token_id=42, + generate_speech=False, + finish_speech=False, + force_audio_bos_pending=False, + audio_bos_id=42, + audio_eos_id=99, + ) + ) + + assert final_token == 42 + assert next_speech_active is True + assert next_force_pending is False + + +def test_resolve_next_speech_state_force_text_abos_overrides_sampled_token(): + final_token, next_speech_active, next_force_pending = ( + FunAudioChatForConditionalGeneration._resolve_next_speech_state( + sampled_token_id=7, + generate_speech=False, + finish_speech=False, + force_audio_bos_pending=True, + audio_bos_id=42, + audio_eos_id=99, + ) + ) + + assert final_token == 42 + assert next_speech_active is True + assert next_force_pending is False + + +def test_resolve_next_speech_state_finish_speech_overrides_final_token_to_audio_eos(): + final_token, next_speech_active, next_force_pending = ( + FunAudioChatForConditionalGeneration._resolve_next_speech_state( + sampled_token_id=7, + generate_speech=True, + finish_speech=True, + force_audio_bos_pending=False, + audio_bos_id=42, + audio_eos_id=99, + ) + ) + + assert final_token == 99 + assert next_speech_active is False + assert next_force_pending 
is False + + +def test_postprocess_sampled_tokens_updates_buffer_from_final_sampled_token(): + model = _make_model_stub() + sampled_token_ids = torch.tensor([42], dtype=torch.long) + model_intermediate_buffer = { + "req0": { + fac_mod._GENERATE_SPEECH_KEY: False, + fac_mod._FORCE_AUDIO_BOS_KEY: False, + fac_mod._FINISH_SPEECH_KEY: False, + fac_mod._RAW_TEXT_TOKEN_ID_KEY: 42, + } + } + + updated = model.postprocess_sampled_tokens( + sampled_token_ids=sampled_token_ids, + req_ids=["req0"], + req_id_to_index={"req0": 0}, + model_intermediate_buffer=model_intermediate_buffer, + ) + + assert updated.tolist() == [42] + assert model_intermediate_buffer["req0"][fac_mod._GENERATE_SPEECH_KEY] is True + assert model_intermediate_buffer["req0"][fac_mod._FORCE_AUDIO_BOS_KEY] is False + assert fac_mod._FINISH_SPEECH_KEY not in model_intermediate_buffer["req0"] + + +def test_postprocess_sampled_tokens_force_text_abos_overrides_sampled_token(): + model = _make_model_stub() + sampled_token_ids = torch.tensor([7], dtype=torch.long) + model_intermediate_buffer = { + "req0": { + fac_mod._GENERATE_SPEECH_KEY: False, + fac_mod._FORCE_AUDIO_BOS_KEY: True, + fac_mod._FINISH_SPEECH_KEY: False, + fac_mod._RAW_TEXT_TOKEN_ID_KEY: 7, + } + } + + updated = model.postprocess_sampled_tokens( + sampled_token_ids=sampled_token_ids, + req_ids=["req0"], + req_id_to_index={"req0": 0}, + model_intermediate_buffer=model_intermediate_buffer, + ) + + assert updated.tolist() == [42] + assert model_intermediate_buffer["req0"][fac_mod._GENERATE_SPEECH_KEY] is True + assert model_intermediate_buffer["req0"][fac_mod._FORCE_AUDIO_BOS_KEY] is False + + +def test_postprocess_sampled_tokens_uses_raw_argmax_when_text_greedy_is_enabled(): + model = _make_model_stub() + sampled_token_ids = torch.tensor([7], dtype=torch.long) + model_intermediate_buffer = { + "req0": { + fac_mod._GENERATE_SPEECH_KEY: False, + fac_mod._FORCE_AUDIO_BOS_KEY: False, + fac_mod._FINISH_SPEECH_KEY: False, + fac_mod._RAW_TEXT_TOKEN_ID_KEY: 
42, + } + } + + updated = model.postprocess_sampled_tokens( + sampled_token_ids=sampled_token_ids, + req_ids=["req0"], + req_id_to_index={"req0": 0}, + model_intermediate_buffer=model_intermediate_buffer, + ) + + assert updated.tolist() == [42] + assert model_intermediate_buffer["req0"][fac_mod._GENERATE_SPEECH_KEY] is True + assert model_intermediate_buffer["req0"][fac_mod._FORCE_AUDIO_BOS_KEY] is False + + +def test_postprocess_sampled_tokens_respects_sampled_token_when_text_greedy_disabled(): + model = _make_model_stub() + model.sp_gen_kwargs["text_greedy"] = False + sampled_token_ids = torch.tensor([7], dtype=torch.long) + model_intermediate_buffer = { + "req0": { + fac_mod._GENERATE_SPEECH_KEY: False, + fac_mod._FORCE_AUDIO_BOS_KEY: False, + fac_mod._FINISH_SPEECH_KEY: False, + fac_mod._RAW_TEXT_TOKEN_ID_KEY: 42, + } + } + + updated = model.postprocess_sampled_tokens( + sampled_token_ids=sampled_token_ids, + req_ids=["req0"], + req_id_to_index={"req0": 0}, + model_intermediate_buffer=model_intermediate_buffer, + ) + + assert updated.tolist() == [7] + assert model_intermediate_buffer["req0"][fac_mod._GENERATE_SPEECH_KEY] is False + assert model_intermediate_buffer["req0"][fac_mod._FORCE_AUDIO_BOS_KEY] is False + + +def test_postprocess_sampled_tokens_overwrites_emitted_token_to_audio_eos_on_finish(): + model = _make_model_stub() + sampled_token_ids = torch.tensor([7], dtype=torch.long) + model_intermediate_buffer = { + "req0": { + fac_mod._GENERATE_SPEECH_KEY: True, + fac_mod._FORCE_AUDIO_BOS_KEY: False, + fac_mod._FINISH_SPEECH_KEY: True, + fac_mod._RAW_TEXT_TOKEN_ID_KEY: 7, + } + } + + updated = model.postprocess_sampled_tokens( + sampled_token_ids=sampled_token_ids, + req_ids=["req0"], + req_id_to_index={"req0": 0}, + model_intermediate_buffer=model_intermediate_buffer, + ) + + assert updated.tolist() == [99] + assert model_intermediate_buffer["req0"][fac_mod._GENERATE_SPEECH_KEY] is False + assert 
model_intermediate_buffer["req0"][fac_mod._FORCE_AUDIO_BOS_KEY] is False + assert fac_mod._FINISH_SPEECH_KEY not in model_intermediate_buffer["req0"] + + +def test_postprocess_sampled_tokens_noops_for_spec_decode_shapes(): + model = _make_model_stub() + sampled_token_ids = torch.tensor([[7, 8]], dtype=torch.long) + model_intermediate_buffer = { + "req0": { + fac_mod._GENERATE_SPEECH_KEY: False, + fac_mod._FORCE_AUDIO_BOS_KEY: True, + fac_mod._FINISH_SPEECH_KEY: False, + fac_mod._RAW_TEXT_TOKEN_ID_KEY: 42, + } + } + + updated = model.postprocess_sampled_tokens( + sampled_token_ids=sampled_token_ids, + req_ids=["req0"], + req_id_to_index={"req0": 0}, + model_intermediate_buffer=model_intermediate_buffer, + ) + + assert torch.equal(updated, sampled_token_ids) + assert model_intermediate_buffer["req0"][fac_mod._FORCE_AUDIO_BOS_KEY] is True + + +def test_chunked_prefill_preprocess_keeps_speech_inactive(): + model = _make_model_stub() + + _, _, first_update = model.preprocess( + input_ids=torch.tensor([1, 2, 3], dtype=torch.long), + input_embeds=None, + ) + _, _, second_update = model.preprocess( + input_ids=torch.tensor([4, 5], dtype=torch.long), + input_embeds=None, + **first_update, + ) + + assert first_update[fac_mod._GENERATE_SPEECH_KEY] is False + assert first_update[fac_mod._FORCE_AUDIO_BOS_KEY] is True + assert second_update[fac_mod._GENERATE_SPEECH_KEY] is False + assert second_update[fac_mod._FORCE_AUDIO_BOS_KEY] is True + assert torch.equal(first_update["audio_token_ids"], torch.full((1, 5), -1, dtype=torch.long)) + assert torch.equal(second_update["audio_token_ids"], torch.full((1, 5), -1, dtype=torch.long)) + + +def test_preprocess_single_token_text_decode_keeps_input_id_path(): + model = _make_model_stub() + + _, req_embeds, _ = model.preprocess( + input_ids=torch.tensor([7], dtype=torch.long), + input_embeds=None, + ) + + assert req_embeds is None + + +def test_preprocess_first_speech_step_without_codec_history_keeps_input_id_path(): + model = 
_make_model_stub() + + _, req_embeds, _ = model.preprocess( + input_ids=torch.tensor([42], dtype=torch.long), + input_embeds=None, + **{ + fac_mod._GENERATE_SPEECH_KEY: True, + fac_mod._SPEECH_IDS_KEY: torch.empty((1, 0), dtype=torch.long), + }, + ) + + assert req_embeds is None + + +def test_preprocess_active_speech_with_codec_history_blends_audio_features(): + model = _make_model_stub(hidden_size=4) + + _, req_embeds, _ = model.preprocess( + input_ids=torch.tensor([42], dtype=torch.long), + input_embeds=None, + **{ + fac_mod._GENERATE_SPEECH_KEY: True, + fac_mod._SPEECH_IDS_KEY: torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long), + }, + ) + + assert torch.equal(req_embeds, torch.full((1, 4), 1.0)) + + +def test_run_audio_sidecar_decode_warmup_updates_cache_only(): + model = _make_model_stub(hidden_size=4) + + class AudioInvertTowerStub: + def __init__(self): + self.crq_audio_embeds = None + self.crq_past_key_values = None + self.crq_do_sample = None + self.crq_logits_processor = None + self.crq_speech_ids = None + + def crq_generate_forward(self, *, inputs_embeds, return_dict=True): + del return_dict + self.last_inputs_embeds = inputs_embeds + self.crq_audio_embeds = torch.full((1, 4), 5.0, dtype=torch.float32, device=inputs_embeds.device) + self.crq_past_key_values = (torch.full((1, 1), 7.0, dtype=torch.float32, device=inputs_embeds.device),) + + model.audio_invert_tower = AudioInvertTowerStub() + + warmup_state = model._run_audio_sidecar_decode_warmup( + hidden_state=torch.zeros(4, dtype=torch.float32), + current_input_token_id=7, + speech_ids=torch.empty((1, 0), dtype=torch.long), + cached_audio_embeds=None, + cached_past_key_values=None, + logits_processor=[], + do_sample=True, + ) + + assert list(model.audio_invert_tower.last_inputs_embeds.shape) == [1, 1, 4] + assert torch.equal(warmup_state[fac_mod._CRQ_AUDIO_EMBEDS_KEY], torch.full((1, 4), 5.0)) + assert torch.equal(warmup_state[fac_mod._CRQ_PAST_KEY_VALUES_KEY][0], torch.full((1, 1), 7.0)) + + +def 
test_postprocess_prefill_warmup_updates_cache_without_emitting_audio(): + model = _make_model_stub(hidden_size=4) + model._batch_sidecar_results = [ + { + fac_mod._AUDIO_TOKEN_IDS_KEY: torch.full((1, 5), -1, dtype=torch.long), + fac_mod._CRQ_AUDIO_EMBEDS_KEY: None, + fac_mod._CRQ_PAST_KEY_VALUES_KEY: None, + fac_mod._FORCE_AUDIO_BOS_KEY: True, + fac_mod._FINISH_SPEECH_KEY: False, + fac_mod._GENERATE_SPEECH_KEY: False, + fac_mod._RAW_TEXT_TOKEN_ID_KEY: 1, + fac_mod._SPEECH_IDS_KEY: torch.empty((1, 0), dtype=torch.long), + "_run_prefill_crq_warmup": True, + "_prefill_input_ids": torch.tensor([1, 2, 3], dtype=torch.long), + "_prefill_crq_logits_processor": [], + "_prefill_crq_do_sample": False, + "audio_token_ids": torch.full((1, 5), -1, dtype=torch.long), + } + ] + model._postprocess_cursor = 0 + + def _prefill_warmup(**kwargs): + del kwargs + return { + fac_mod._CRQ_AUDIO_EMBEDS_KEY: torch.full((1, 4), 9.0), + fac_mod._CRQ_PAST_KEY_VALUES_KEY: (torch.full((1, 1), 3.0),), + } + + model._run_audio_sidecar_prefill_warmup = _prefill_warmup + + output = model.postprocess(torch.zeros((3, 4), dtype=torch.float32)) + + assert torch.equal(output["audio_token_ids"], torch.full((1, 5), -1, dtype=torch.long)) + assert torch.equal(output[fac_mod._CRQ_AUDIO_EMBEDS_KEY], torch.full((1, 4), 9.0)) + assert torch.equal(output[fac_mod._CRQ_PAST_KEY_VALUES_KEY][0], torch.full((1, 1), 3.0)) diff --git a/tests/model_executor/stage_input_processors/test_funaudiochat.py b/tests/model_executor/stage_input_processors/test_funaudiochat.py new file mode 100644 index 00000000000..5dc0fd4a0c2 --- /dev/null +++ b/tests/model_executor/stage_input_processors/test_funaudiochat.py @@ -0,0 +1,88 @@ +from types import SimpleNamespace + +import pytest +import torch + +from vllm_omni.model_executor.stage_input_processors.funaudiochat import funaudiochat2code2wav + + +def _stage_list(audio_token_ids=None, speech_ids=None): + output = SimpleNamespace(multimodal_output={}) + if audio_token_ids is not None: 
+ output.multimodal_output["audio_token_ids"] = audio_token_ids + if speech_ids is not None: + output.multimodal_output["speech_ids"] = speech_ids + stage_output = SimpleNamespace(outputs=[output]) + stage = SimpleNamespace(engine_outputs=[stage_output]) + return [stage] + + +def test_filters_invalid_audio_tokens_for_code2wav(): + stage_inputs = _stage_list(torch.tensor([-1, 0, 17, 6560, 6561, 7000], dtype=torch.long)) + + prompts = funaudiochat2code2wav(stage_inputs, engine_input_source=[0]) + + assert len(prompts) == 1 + assert prompts[0]["prompt_token_ids"] == [0, 17, 6560] + assert prompts[0]["multi_modal_data"] is None + assert prompts[0]["mm_processor_kwargs"] is None + + +def test_accepts_list_audio_tokens(): + stage_inputs = _stage_list([5, 12, 6559, 6561]) + + prompts = funaudiochat2code2wav(stage_inputs, engine_input_source=[0]) + + assert prompts[0]["prompt_token_ids"] == [5, 12, 6559] + + +def test_accepts_step_aligned_audio_token_rows(): + stage_inputs = _stage_list( + torch.tensor( + [ + [0, 0, 0, 0, 0], + [11, 12, 13, 14, 15], + [16, 17, 18, 19, 20], + ], + dtype=torch.long, + ) + ) + + prompts = funaudiochat2code2wav(stage_inputs, engine_input_source=[0]) + + assert prompts[0]["prompt_token_ids"] == [0, 0, 0, 0, 0, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] + + +def test_drops_fully_negative_step_aligned_rows(): + stage_inputs = _stage_list( + torch.tensor( + [ + [-1, -1, -1, -1, -1], + [0, 0, 0, 0, 0], + [21, 22, 23, 24, 25], + ], + dtype=torch.long, + ) + ) + + prompts = funaudiochat2code2wav(stage_inputs, engine_input_source=[0]) + + assert prompts[0]["prompt_token_ids"] == [0, 0, 0, 0, 0, 21, 22, 23, 24, 25] + + +def test_prefers_incremental_audio_token_ids_over_cumulative_speech_ids(): + stage_inputs = _stage_list( + audio_token_ids=torch.tensor([[11, 12, 13, 14, 15], [16, 17, 18, 19, 20]], dtype=torch.long), + speech_ids=torch.tensor([101, 102, 103], dtype=torch.long), + ) + + prompts = funaudiochat2code2wav(stage_inputs, engine_input_source=[0]) 
+ + assert prompts[0]["prompt_token_ids"] == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20] + + +def test_raises_when_audio_token_ids_are_missing(): + stage_inputs = _stage_list(None) + + with pytest.raises(ValueError, match="speech_ids|audio_token_ids"): + funaudiochat2code2wav(stage_inputs, engine_input_source=[0]) diff --git a/tests/test_outputs.py b/tests/test_outputs.py index 8da19a4980e..f2758a94186 100644 --- a/tests/test_outputs.py +++ b/tests/test_outputs.py @@ -14,6 +14,22 @@ class TestOmniRequestOutput: """Tests for OmniRequestOutput class.""" + def _make_pipeline_request_output(self, mocker: MockerFixture, **overrides): + mock_request_output = mocker.Mock() + mock_request_output.request_id = "pipeline-123" + mock_request_output.prompt_token_ids = None + mock_request_output.outputs = [] + mock_request_output.encoder_prompt_token_ids = None + mock_request_output.prompt_logprobs = None + mock_request_output.num_cached_tokens = None + mock_request_output.kv_transfer_params = None + mock_request_output.multimodal_output = None + + for key, value in overrides.items(): + setattr(mock_request_output, key, value) + + return mock_request_output + def test_from_diffusion(self): """Test creating output from diffusion model.""" images = [Image.new("RGB", (64, 64), color="red")] @@ -32,15 +48,13 @@ def test_from_diffusion(self): def test_from_pipeline(self, mocker: MockerFixture): """Test creating output from pipeline stage.""" - mock_request_output = mocker.Mock() - mock_request_output.request_id = "pipeline-123" - mock_request_output.prompt_token_ids = [1, 2, 3] - mock_request_output.outputs = [mocker.Mock()] - mock_request_output.encoder_prompt_token_ids = None - mock_request_output.prompt_logprobs = None - mock_request_output.num_cached_tokens = 10 - mock_request_output.kv_transfer_params = None - mock_request_output.multimodal_output = {"image": mocker.Mock()} + mock_request_output = self._make_pipeline_request_output( + mocker, + prompt_token_ids=[1, 2, 3], + 
outputs=[mocker.Mock()], + num_cached_tokens=10, + multimodal_output={"image": mocker.Mock()}, + ) output = OmniRequestOutput.from_pipeline( stage_id=0, @@ -55,8 +69,7 @@ def test_from_pipeline(self, mocker: MockerFixture): def test_prompt_token_ids_property(self, mocker: MockerFixture): """Test prompt_token_ids property for streaming compatibility.""" - mock_request_output = mocker.Mock() - mock_request_output.prompt_token_ids = [1, 2, 3, 4, 5] + mock_request_output = self._make_pipeline_request_output(mocker, prompt_token_ids=[1, 2, 3, 4, 5]) output = OmniRequestOutput.from_pipeline( stage_id=0, @@ -78,8 +91,7 @@ def test_prompt_token_ids_none_when_no_request_output(self): def test_outputs_property(self, mocker: MockerFixture): """Test outputs property for chat completion compatibility.""" mock_output = mocker.Mock() - mock_request_output = mocker.Mock() - mock_request_output.outputs = [mock_output] + mock_request_output = self._make_pipeline_request_output(mocker, outputs=[mock_output]) output = OmniRequestOutput.from_pipeline( stage_id=0, @@ -89,6 +101,32 @@ def test_outputs_property(self, mocker: MockerFixture): assert output.outputs == [mock_output] + def test_outputs_property_flattens_request_output_list(self, mocker: MockerFixture): + """Test outputs property when pipeline request_output is a list.""" + mock_output_a = mocker.Mock() + mock_output_b = mocker.Mock() + mock_request_output_a = self._make_pipeline_request_output( + mocker, + request_id="pipeline-123-a", + prompt_token_ids=[1, 2, 3], + outputs=[mock_output_a], + ) + mock_request_output_b = self._make_pipeline_request_output( + mocker, + request_id="pipeline-123-b", + prompt_token_ids=[4, 5, 6], + outputs=[mock_output_b], + ) + + output = OmniRequestOutput( + stage_id=0, + final_output_type="text", + request_output=[mock_request_output_a, mock_request_output_b], + ) + + assert output.outputs == [mock_output_a, mock_output_b] + assert output.prompt_token_ids == [1, 2, 3] + def 
test_outputs_empty_when_no_request_output(self): """Test outputs returns empty list when no request_output.""" output = OmniRequestOutput.from_diffusion( @@ -100,8 +138,7 @@ def test_outputs_empty_when_no_request_output(self): def test_encoder_prompt_token_ids_property(self, mocker: MockerFixture): """Test encoder_prompt_token_ids property.""" - mock_request_output = mocker.Mock() - mock_request_output.encoder_prompt_token_ids = [10, 20, 30] + mock_request_output = self._make_pipeline_request_output(mocker, encoder_prompt_token_ids=[10, 20, 30]) output = OmniRequestOutput.from_pipeline( stage_id=0, @@ -113,8 +150,7 @@ def test_encoder_prompt_token_ids_property(self, mocker: MockerFixture): def test_num_cached_tokens_property(self, mocker: MockerFixture): """Test num_cached_tokens property.""" - mock_request_output = mocker.Mock() - mock_request_output.num_cached_tokens = 42 + mock_request_output = self._make_pipeline_request_output(mocker, num_cached_tokens=42) output = OmniRequestOutput.from_pipeline( stage_id=0, @@ -126,11 +162,9 @@ def test_num_cached_tokens_property(self, mocker: MockerFixture): def test_multimodal_output_property(self, mocker: MockerFixture): """Test multimodal_output property.""" - mock_request_output = mocker.Mock() mock_audio = mocker.Mock() expected_output = {"audio": mock_audio} - mock_request_output.outputs = [] - mock_request_output.multimodal_output = expected_output + mock_request_output = self._make_pipeline_request_output(mocker, multimodal_output=expected_output) output = OmniRequestOutput.from_pipeline( stage_id=0, @@ -158,8 +192,7 @@ def test_to_dict_diffusion(self): def test_to_dict_pipeline(self, mocker: MockerFixture): """Test to_dict for pipeline output.""" - mock_request_output = mocker.Mock() - mock_request_output.request_id = "pipeline-123" + mock_request_output = self._make_pipeline_request_output(mocker) output = OmniRequestOutput.from_pipeline( stage_id=0, diff --git a/tests/worker/test_omni_gpu_model_runner.py 
b/tests/worker/test_omni_gpu_model_runner.py index b2d61931558..1379e1d1e5b 100644 --- a/tests/worker/test_omni_gpu_model_runner.py +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -1,9 +1,12 @@ from contextlib import contextmanager from types import SimpleNamespace +import numpy as np import pytest import torch +from vllm.v1.outputs import SamplerOutput +from vllm_omni.worker.gpu_ar_model_runner import GPUARModelRunner from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner pytestmark = [pytest.mark.core_model, pytest.mark.cpu] @@ -38,6 +41,52 @@ def __init__(self): # No real forward needed for these tests. +class ReplaceSampledTokensModel(torch.nn.Module): + """Returns a replacement sampled-token tensor from the post-sample hook.""" + + def __init__(self): + super().__init__() + self.observed_sampled_token_ids = None + + def postprocess_sampled_tokens(self, sampled_token_ids, req_ids, req_id_to_index, model_intermediate_buffer): + assert req_ids == ["r1", "r2"] + assert req_id_to_index == {"r1": 0, "r2": 1} + assert model_intermediate_buffer == {"r1": {"token": 1}, "r2": {"token": 2}} + self.observed_sampled_token_ids = sampled_token_ids.clone() + return sampled_token_ids + 10 + + +class OverlaySampledTokensModel(torch.nn.Module): + """Validates that post-sample hooks receive overlaid pending updates.""" + + def __init__(self): + super().__init__() + self.observed_buffer = None + self.pooler_output_buffer_keys = ("audio_token_ids",) + + def postprocess_sampled_tokens(self, sampled_token_ids, req_ids, req_id_to_index, model_intermediate_buffer): + del sampled_token_ids, req_ids, req_id_to_index + self.observed_buffer = model_intermediate_buffer + return None + + +class RawTokenPreprocessModel(torch.nn.Module): + """Tracks whether preprocess receives raw input ids without an embed slice.""" + + has_preprocess = True + requires_raw_input_tokens = True + + def __init__(self, hidden_size: int = 4): + super().__init__() + self.hidden_size = hidden_size + 
self.observed_input_embeds = [] + + def preprocess(self, input_ids, input_embeds, **info_dict): + self.observed_input_embeds.append(input_embeds) + req_embeds = input_ids.to(dtype=torch.float32).unsqueeze(-1).repeat(1, self.hidden_size) + return input_ids + 100, req_embeds, {"marker_seen": info_dict.get("marker")} + + class DummyTalkerMTP(torch.nn.Module): """A fake talker_mtp module for deterministic CPU testing.""" @@ -116,6 +165,83 @@ class _DummyVllmConfig: return runner +def _make_preprocess_runner(model, hidden_size=4): + runner = object.__new__(OmniGPUModelRunner) + runner.model = model + runner.model_config = SimpleNamespace(is_encoder_decoder=False) + runner.supports_mm_inputs = False + runner.enable_prompt_embeds = False + runner.uses_mrope = False + runner.uses_xdrope_dim = 0 + runner.positions = DummyBuffer(torch.arange(8, dtype=torch.int64)) + runner.input_ids = DummyBuffer(torch.tensor([1, 2, 3, 4], dtype=torch.int32)) + runner.inputs_embeds = DummyBuffer(torch.full((4, hidden_size), -1.0, dtype=torch.float32)) + runner.input_batch = SimpleNamespace( + req_ids=["r1"], + num_computed_tokens_cpu=np.array([0], dtype=np.int32), + ) + runner.requests = {"r1": SimpleNamespace(prompt_token_ids=[], mm_features=[])} + runner.model_intermediate_buffer = {"r1": {"marker": "r1"}} + runner.query_start_loc = SimpleNamespace(cpu=torch.tensor([0], dtype=torch.int32)) + runner.dtype = torch.float32 + runner.device = torch.device("cpu") + runner.vllm_config = SimpleNamespace(model_config=SimpleNamespace(async_chunk=False)) + runner._init_model_kwargs = lambda: {} + return runner + + +class StopAfterBookkeepingError(Exception): + pass + + +def _make_sample_tokens_runner(model): + runner = object.__new__(GPUARModelRunner) + runner.model = model + runner.speculative_config = None + runner.use_async_scheduling = False + runner.input_batch = SimpleNamespace( + req_ids=["r1", "r2"], + req_id_to_index={"r1": 0, "r2": 1}, + sampling_metadata=SimpleNamespace(no_penalties=True), 
+ prev_sampled_token_ids=None, + num_tokens_no_spec=np.array([1, 1], dtype=np.int32), + token_ids_cpu=np.array([[1, 0, 0, 0], [2, 0, 0, 0]], dtype=np.int32), + vocab_size=32000, + ) + runner.model_intermediate_buffer = {"r1": {"token": 1}, "r2": {"token": 2}} + runner.requests = { + "r1": SimpleNamespace(output_token_ids=[1]), + "r2": SimpleNamespace(output_token_ids=[2]), + } + runner.execute_model_state = ( + SimpleNamespace(total_num_scheduled_tokens=2, num_scheduled_tokens={"r1": 1, "r2": 1}), + None, + None, + None, + torch.zeros((2, 4), dtype=torch.float32), + torch.zeros((2, 4), dtype=torch.float32), + None, + None, + None, + None, + None, + ) + runner._sample = lambda logits, spec_decode_metadata: SamplerOutput( + sampled_token_ids=torch.tensor([[1], [2]], dtype=torch.int32), + logprobs_tensors=None, + ) + runner.max_model_len = 4 + runner.query_start_loc = SimpleNamespace(cpu=torch.tensor([0, 1], dtype=torch.int32)) + runner._omni_num_scheduled_tokens_np = np.array([1, 1], dtype=np.int32) + runner.vllm_config = SimpleNamespace(model_config=SimpleNamespace(engine_output_type="omni")) + runner.model_config = SimpleNamespace(enable_return_routed_experts=False) + runner.supports_mm_inputs = False + runner.kv_connector_output = None + runner.eplb_step = lambda: None + runner.finalize_kv_connector = lambda: None + return runner + + def test_talker_mtp_forward_cpu_updates_inputs_and_info(monkeypatch): # Patch the module-level `set_forward_context` symbol used inside # OmniGPUModelRunner._talker_mtp_forward. @@ -250,3 +376,99 @@ def test_maybe_attach_mimo_audio_req_infos_no_req_state_returns_input(): # When no req_state, helper should be a no-op. 
assert result is req_infos + + +def test_sample_tokens_applies_postprocessed_tokens_before_bookkeeping(): + runner = _make_sample_tokens_runner(ReplaceSampledTokensModel()) + captured = {} + + def fake_bookkeeping( + self, + scheduler_output, + sampler_output, + logits, + hidden_states, + num_scheduled_tokens, + spec_decode_metadata, + ): + captured["sampled_token_ids"] = sampler_output.sampled_token_ids.clone() + raise StopAfterBookkeepingError + + runner._bookkeeping_sync = fake_bookkeeping.__get__(runner, type(runner)) + + with pytest.raises(StopAfterBookkeepingError): + GPUARModelRunner.sample_tokens(runner, grammar_output=None) + + assert torch.equal(runner.model.observed_sampled_token_ids, torch.tensor([[1], [2]], dtype=torch.int32)) + assert torch.equal(captured["sampled_token_ids"], torch.tensor([[11], [12]], dtype=torch.int32)) + + +def test_sample_tokens_passes_pending_updates_to_postprocess_without_committing_before_bookkeeping(): + runner = _make_sample_tokens_runner(OverlaySampledTokensModel()) + + def fake_collect(*args, **kwargs): + del args, kwargs + return {"r1": {"pending": 11}, "r2": {"pending": 22}} + + captured = {} + + def fake_bookkeeping( + self, + scheduler_output, + sampler_output, + logits, + hidden_states, + num_scheduled_tokens, + spec_decode_metadata, + ): + del scheduler_output, sampler_output, logits, hidden_states, num_scheduled_tokens, spec_decode_metadata + captured["buffer_before_bookkeeping"] = { + req_id: dict(info) for req_id, info in self.model_intermediate_buffer.items() + } + raise StopAfterBookkeepingError + + runner._collect_additional_information_updates = fake_collect + runner._bookkeeping_sync = fake_bookkeeping.__get__(runner, type(runner)) + + with pytest.raises(StopAfterBookkeepingError): + GPUARModelRunner.sample_tokens(runner, grammar_output=None) + + assert runner.model.observed_buffer == { + "r1": {"token": 1, "pending": 11}, + "r2": {"token": 2, "pending": 22}, + } + assert captured["buffer_before_bookkeeping"] 
== {"r1": {"token": 1}, "r2": {"token": 2}} + + +def test_preprocess_passes_none_input_embeds_for_raw_token_models(monkeypatch): + import vllm_omni.worker.gpu_model_runner as mod + + monkeypatch.setattr(mod, "get_pp_group", lambda: SimpleNamespace(is_first_rank=True)) + + runner = _make_preprocess_runner(RawTokenPreprocessModel(hidden_size=4), hidden_size=4) + scheduler_output = SimpleNamespace( + total_num_scheduled_tokens=2, + num_scheduled_tokens={"r1": 2}, + scheduled_encoder_inputs=None, + ) + + input_ids, inputs_embeds, *_ = OmniGPUModelRunner._preprocess( + runner, + scheduler_output, + num_input_tokens=2, + ) + + assert runner.model.observed_input_embeds == [None] + assert torch.equal(input_ids, torch.tensor([101, 102], dtype=torch.int32)) + assert inputs_embeds.data_ptr() == runner.inputs_embeds.gpu[:2].data_ptr() + assert torch.equal( + inputs_embeds, + torch.tensor( + [ + [1.0, 1.0, 1.0, 1.0], + [2.0, 2.0, 2.0, 2.0], + ], + dtype=torch.float32, + ), + ) + assert runner.model_intermediate_buffer["r1"]["marker_seen"] == "r1" diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index d733e7bc4bb..56a5b651cbe 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -7,6 +7,9 @@ from vllm.transformers_utils.gguf_utils import is_gguf from vllm_omni.config import OmniModelConfig +from vllm_omni.model_executor.models.cosyvoice3.hf_config_utils import ( + resolve_bundled_hf_config_path as _resolve_bundled_hf_config_path, +) from vllm_omni.plugins import load_omni_general_plugins logger = init_logger(__name__) @@ -113,6 +116,7 @@ def create_model_config(self) -> OmniModelConfig: # register omni models to avoid model not found error self._ensure_omni_models_registered() + hf_config_path = _resolve_bundled_hf_config_path(self.model_arch, self.hf_config_path) # Keep compatibility when async args are constructed from partial payloads. 
limit_mm_per_prompt = getattr(self, "limit_mm_per_prompt", {}) @@ -143,7 +147,7 @@ def create_model_config(self) -> OmniModelConfig: # Base ModelConfig fields (matching vLLM's EngineArgs.create_model_config) model=self.model, model_weights=self.model_weights, - hf_config_path=self.hf_config_path, + hf_config_path=hf_config_path, runner=self.runner, convert=self.convert, tokenizer=self.tokenizer, @@ -275,6 +279,7 @@ def create_model_config(self) -> OmniModelConfig: # register omni models to avoid model not found error self._ensure_omni_models_registered() + hf_config_path = _resolve_bundled_hf_config_path(self.model_arch, self.hf_config_path) # Keep compatibility when async args are constructed from partial payloads. limit_mm_per_prompt = getattr(self, "limit_mm_per_prompt", {}) @@ -305,7 +310,7 @@ def create_model_config(self) -> OmniModelConfig: # Base ModelConfig fields (matching vLLM's EngineArgs.create_model_config) model=self.model, model_weights=self.model_weights, - hf_config_path=self.hf_config_path, + hf_config_path=hf_config_path, runner=self.runner, convert=self.convert, tokenizer=self.tokenizer, diff --git a/vllm_omni/engine/input_processor.py b/vllm_omni/engine/input_processor.py index 5bbd16b38d3..0d080d95209 100644 --- a/vllm_omni/engine/input_processor.py +++ b/vllm_omni/engine/input_processor.py @@ -11,6 +11,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.multimodal.utils import argsort_mm_positions +from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams from vllm.renderers import BaseRenderer from vllm.sampling_params import SamplingParams @@ -184,7 +185,11 @@ def process_inputs( tokenization_kwargs=tokenization_kwargs, ) - self._platform_validate_request(processed_inputs, params) + platform_validate_request = getattr(self, "_platform_validate_request", None) + if platform_validate_request is not None: + 
platform_validate_request(processed_inputs, params) + else: + current_platform.validate_request(processed_inputs, params) encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) self._validate_model_inputs(encoder_inputs, decoder_inputs) diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 488b986d8ee..9f0b4738fda 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -444,6 +444,15 @@ def _is_async_chunk_enable(self, stage_args: list) -> bool: engine_args = getattr(stage_args[0], "engine_args", None) return bool(getattr(engine_args, "async_chunk", False)) + @staticmethod + def _get_stage_model(stage: OmniStage, fallback_model: str) -> str: + """Prefer a per-stage model override when present.""" + engine_args = getattr(stage, "engine_args", None) + stage_model = getattr(engine_args, "model", None) + if not stage_model and isinstance(engine_args, dict): + stage_model = engine_args.get("model") + return stage_model or fallback_model + def _start_stages(self, model: str) -> None: """Start all stage processes.""" if self.worker_backend == "ray": @@ -510,12 +519,14 @@ def _start_stages(self, model: str) -> None: ) continue + stage_model = self._get_stage_model(stage, model) stage.init_stage_worker( - model, + stage_model, is_async=self.is_async, shm_threshold_bytes=self._shm_threshold_bytes, ctx=self._ctx if self.worker_backend != "ray" else None, batch_timeout=self.batch_timeout, + stage_configs_path=self.config_path, connectors_config=stage_connectors_config, worker_backend=self.worker_backend, ray_placement_group=self._ray_pg, diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 1298322676c..d6a278aa37d 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -496,6 +496,7 @@ def init_stage_worker( shm_threshold_bytes: int = 65536, ctx: mp.context.BaseContext | None = None, batch_timeout: int = 10, + stage_configs_path: 
str | None = None, connectors_config: dict | None = None, worker_backend: str = "multi_process", ignore_runtime_config: bool = False, @@ -540,6 +541,7 @@ def init_stage_worker( "stage_id": self.stage_id, "engine_args": engine_args, "runtime": runtime_cfg, + "stage_configs_path": stage_configs_path, "shm_threshold_bytes": self._shm_threshold_bytes, "connectors_config": connectors_config or {}, "stage_type": self.stage_type, @@ -825,6 +827,7 @@ def _stage_worker( stage_id = stage_payload["stage_id"] engine_args = stage_payload.get("engine_args", {}) runtime_cfg = stage_payload.get("runtime", {}) + stage_configs_path = stage_payload.get("stage_configs_path") shm_threshold_bytes = int(stage_payload.get("shm_threshold_bytes", 65536)) connectors_config = stage_payload.get("connectors_config", {}) stage_type: Literal["llm", "diffusion"] = stage_payload.get("stage_type", "llm") @@ -896,6 +899,8 @@ def _stage_worker( else: engine_args = filter_dataclass_kwargs(OmniEngineArgs, engine_args) engine_args.pop("model", None) + if stage_configs_path is not None: + engine_args["stage_configs_path"] = stage_configs_path # Default to LLM engine stage_engine = OmniLLM(model=model, **engine_args) diff --git a/vllm_omni/model_executor/models/__init__.py b/vllm_omni/model_executor/models/__init__.py index 68074bbb996..97df6f16573 100644 --- a/vllm_omni/model_executor/models/__init__.py +++ b/vllm_omni/model_executor/models/__init__.py @@ -1,8 +1,11 @@ from .bagel.bagel import OmniBagelForConditionalGeneration +from .funaudiochat import FunAudioChatCosyVoice3Code2Wav, FunAudioChatForConditionalGeneration from .qwen3_omni import Qwen3OmniMoeForConditionalGeneration from .registry import OmniModelRegistry # noqa: F401 __all__ = [ + "FunAudioChatCosyVoice3Code2Wav", + "FunAudioChatForConditionalGeneration", "Qwen3OmniMoeForConditionalGeneration", "OmniBagelForConditionalGeneration", ] diff --git a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3_code2wav.py 
b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3_code2wav.py index f5e0d04a8ae..1b509b71a3c 100644 --- a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3_code2wav.py +++ b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3_code2wav.py @@ -76,7 +76,6 @@ def __init__(self, config: CosyVoice3Config): pre_lookahead_layer=pre_lookahead_layer, decoder=decoder, ) - # Build HiFiGAN vocoder f0_predictor = CausalConvRNNF0Predictor( num_class=config.hift["f0_predictor"]["num_class"], diff --git a/vllm_omni/model_executor/models/cosyvoice3/hf_config/config.json b/vllm_omni/model_executor/models/cosyvoice3/hf_config/config.json new file mode 100644 index 00000000000..2e803578153 --- /dev/null +++ b/vllm_omni/model_executor/models/cosyvoice3/hf_config/config.json @@ -0,0 +1,3 @@ +{ + "model_type": "cosyvoice3" +} diff --git a/vllm_omni/model_executor/models/cosyvoice3/hf_config_utils.py b/vllm_omni/model_executor/models/cosyvoice3/hf_config_utils.py new file mode 100644 index 00000000000..02143c9768b --- /dev/null +++ b/vllm_omni/model_executor/models/cosyvoice3/hf_config_utils.py @@ -0,0 +1,17 @@ +from pathlib import Path + +_COSYVOICE3_MODEL_ARCHES = { + "CosyVoice3Model", + "FunAudioChatCosyVoice3Code2Wav", +} + + +def resolve_bundled_hf_config_path(model_arch: str, hf_config_path: str | None) -> str | None: + if hf_config_path is not None or model_arch not in _COSYVOICE3_MODEL_ARCHES: + return hf_config_path + + bundled_hf_config_path = Path(__file__).resolve().parent / "hf_config" + if not (bundled_hf_config_path / "config.json").is_file(): + return None + + return str(bundled_hf_config_path) diff --git a/vllm_omni/model_executor/models/funaudiochat/__init__.py b/vllm_omni/model_executor/models/funaudiochat/__init__.py new file mode 100644 index 00000000000..e05bde83236 --- /dev/null +++ b/vllm_omni/model_executor/models/funaudiochat/__init__.py @@ -0,0 +1,4 @@ +from .funaudiochat import FunAudioChatForConditionalGeneration +from .funaudiochat_code2wav import 
FunAudioChatCosyVoice3Code2Wav + +__all__ = ["FunAudioChatForConditionalGeneration", "FunAudioChatCosyVoice3Code2Wav"] diff --git a/vllm_omni/model_executor/models/funaudiochat/common.py b/vllm_omni/model_executor/models/funaudiochat/common.py new file mode 100644 index 00000000000..d86872573dd --- /dev/null +++ b/vllm_omni/model_executor/models/funaudiochat/common.py @@ -0,0 +1,277 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import os +import sys +from collections.abc import Mapping, Sequence +from functools import cached_property +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from transformers import PreTrainedTokenizerFast, WhisperFeatureExtractor +from transformers.feature_extraction_utils import BatchFeature +from vllm.config.multimodal import BaseDummyOptions +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems +from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems, MultiModalDataParser +from vllm.multimodal.processing import ( + BaseDummyInputsBuilder, + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) + + +def ensure_funaudiochat_importable() -> Any: + try: + import funaudiochat # type: ignore + + return funaudiochat + except ImportError: + pass + + env_home = os.environ.get("FUN_AUDIO_CHAT_HOME") + extra_candidates = [Path(env_home).expanduser()] if env_home else [] + + for candidate in extra_candidates: + if candidate and candidate.exists(): + sys.path.insert(0, str(candidate)) + try: + import funaudiochat # type: ignore + + return funaudiochat + except ImportError: + continue + + raise ImportError( + "funaudiochat package is required. 
Install Fun-Audio-Chat into the active " + "environment or set FUN_AUDIO_CHAT_HOME to the repo checkout." + ) + + +def resolve_funaudiochat_root() -> Path: + pkg = ensure_funaudiochat_importable() + pkg_path = Path(pkg.__file__).resolve() + root = pkg_path.parent.parent + if not root.exists(): + raise FileNotFoundError(f"Resolved Fun-Audio-Chat root does not exist: {root}") + return root + + +class FunAudioChatProcessingInfo(BaseProcessingInfo): + token_fps: int = 25 + + @cached_property + def feature_extractor(self) -> WhisperFeatureExtractor: + return WhisperFeatureExtractor.from_pretrained(self.model_id) + + @cached_property + def speech_tokenizer(self) -> PreTrainedTokenizerFast: + return PreTrainedTokenizerFast.from_pretrained(self.model_id, subfolder="speech_tokenizer") + + def get_feature_extractor(self) -> WhisperFeatureExtractor: + return self.feature_extractor + + def get_speech_tokenizer(self) -> PreTrainedTokenizerFast: + return self.speech_tokenizer + + def get_data_parser(self): + return MultiModalDataParser( + target_sr=int(self.feature_extractor.sampling_rate), + target_channels=1, + expected_hidden_size=self._get_expected_hidden_size(), + ) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"audio": None} + + def get_mm_max_tokens_per_item(self, seq_len: int, mm_counts: Mapping[str, int]) -> Mapping[str, int] | None: + del seq_len, mm_counts + cfg = self.get_hf_config() + audio_cfg = getattr(cfg, "audio_config", None) + max_audio_tokens = int(getattr(audio_cfg, "max_source_positions", 1500)) + return {"audio": max_audio_tokens} + + def get_audio_group_size(self) -> int: + cfg = self.get_hf_config() + audio_cfg = getattr(cfg, "audio_config", None) + return int(getattr(audio_cfg, "group_size", 5)) + + +class FunAudioChatDummyInputsBuilder(BaseDummyInputsBuilder[FunAudioChatProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + return 
"<|audio_bos|><|AUDIO|><|audio_eos|>" * int(num_audios) + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + del seq_len + feature_extractor = self.info.get_feature_extractor() + sampling_rate = int(feature_extractor.sampling_rate) + cfg = self.info.get_hf_config() + audio_cfg = getattr(cfg, "audio_config", None) + max_audio_tokens = int(getattr(audio_cfg, "max_source_positions", 1500)) + group_size = self.info.get_audio_group_size() + token_fps = int(getattr(self.info, "token_fps", 25)) + target_num_frames = max(1, max_audio_tokens) * max(1, group_size) + audio_len = max(1, (target_num_frames * sampling_rate + token_fps - 1) // token_fps) + num_audios = int(mm_counts.get("audio", 0)) + audio_overrides = mm_options.get("audio") if mm_options else None + return { + "audio": self._get_dummy_audios( + length=audio_len, + num_audios=num_audios, + overrides=audio_overrides, + ) + } + + +class FunAudioChatMultiModalProcessor(BaseMultiModalProcessor[FunAudioChatProcessingInfo]): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + del mm_kwargs + tokenizer = self.info.get_tokenizer() + text_inputs = tokenizer( + prompt, + return_attention_mask=True, + return_token_type_ids=False, + return_tensors="pt", + **tok_kwargs, + ) + + audios = mm_data.get("audios", []) + if not audios: + return BatchFeature( + { + "input_ids": text_inputs["input_ids"], + "attention_mask": text_inputs["attention_mask"], + } + ) + + feature_extractor = self.info.get_feature_extractor() + sr = int(feature_extractor.sampling_rate) + min_samples = int(getattr(feature_extractor, "n_fft", 400) or 400) + + wavs: list[np.ndarray] = [] + speech_strs: list[str] = [] + + speech_tokenizer = self.info.get_speech_tokenizer() + pad_token = speech_tokenizer.pad_token or 
"<|audio_pad|>" + for audio in audios: + if isinstance(audio, torch.Tensor): + audio = audio.detach().cpu().numpy() + audio_np = np.asarray(audio, dtype=np.float32) + if min_samples > 0 and audio_np.shape[0] < min_samples: + audio_np = np.pad(audio_np, (0, min_samples - audio_np.shape[0]), mode="constant") + + wavs.append(audio_np) + num_frames = int((float(audio_np.shape[0]) / float(sr)) * float(self.info.token_fps)) + speech_strs.append(pad_token * max(1, int(num_frames))) + + audio_group_size = self.info.get_audio_group_size() + speech_inputs = speech_tokenizer( + speech_strs, + return_attention_mask=True, + return_token_type_ids=False, + padding=True, + pad_to_multiple_of=audio_group_size, + return_tensors="pt", + ) + wav_inputs = feature_extractor( + wavs, + sampling_rate=sr, + return_attention_mask=True, + padding="max_length", + return_tensors="pt", + ) + + return BatchFeature( + { + "input_ids": text_inputs["input_ids"], + "attention_mask": text_inputs["attention_mask"], + "speech_ids": speech_inputs["input_ids"], + "speech_attention_mask": speech_inputs["attention_mask"], + "input_features": wav_inputs["input_features"], + "feature_attention_mask": wav_inputs["attention_mask"], + "feature_exist_mask": torch.ones((len(wavs),), dtype=torch.bool), + } + ) + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + del prompt_text, mm_items, hf_processor_mm_kwargs, tokenization_kwargs + return False + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + del hf_inputs, hf_processor_mm_kwargs + return { + "speech_ids": MultiModalFieldConfig.batched("audio"), + "speech_attention_mask": MultiModalFieldConfig.batched("audio"), + "input_features": MultiModalFieldConfig.batched("audio"), + "feature_attention_mask": 
MultiModalFieldConfig.batched("audio"), + "feature_exist_mask": MultiModalFieldConfig.batched("audio"), + } + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + del hf_processor_mm_kwargs + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + audio_token = "<|AUDIO|>" + audio_token_id = vocab[audio_token] + out_mm_data = out_mm_kwargs.get_data() + speech_attention_mask = out_mm_data.get("speech_attention_mask") + if speech_attention_mask is None: + audio_output_lengths: list[int] = [] + else: + assert isinstance(speech_attention_mask, torch.Tensor) + speech_lengths = speech_attention_mask.sum(-1) + group_size = self.info.get_audio_group_size() + audio_output_lengths = ((speech_lengths + group_size - 1) // group_size).tolist() + + def get_replacement(item_idx: int): + num_features = int(audio_output_lengths[item_idx]) if audio_output_lengths else 1 + if num_features <= 0: + audios = mm_items.get_items("audio", AudioProcessorItems) + audio_len = audios.get_audio_length(item_idx) + raise ValueError(f"The audio (len={audio_len}) is too short to be represented inside the model") + audio_tokens = [audio_token_id] * num_features + return PromptUpdateDetails.select_token_id(audio_tokens, embed_token_id=audio_token_id) + + return [PromptReplacement(modality="audio", target=audio_token, replacement=get_replacement)] + + +def register_funaudiochat_processor(model_cls: type[Any]) -> type[Any]: + return MULTIMODAL_REGISTRY.register_processor( + FunAudioChatMultiModalProcessor, + info=FunAudioChatProcessingInfo, + dummy_inputs=FunAudioChatDummyInputsBuilder, + )(model_cls) diff --git a/vllm_omni/model_executor/models/funaudiochat/funaudiochat.py b/vllm_omni/model_executor/models/funaudiochat/funaudiochat.py new file mode 100644 index 00000000000..83883d0a1e7 --- /dev/null +++ 
b/vllm_omni/model_executor/models/funaudiochat/funaudiochat.py @@ -0,0 +1,622 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from collections.abc import Iterable +from types import MethodType +from typing import Any + +import torch +import torch.nn as nn +from transformers.generation.logits_process import ( + LogitsProcessorList, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) +from transformers.modeling_outputs import BaseModelOutput +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.model_executor.models.funaudiochat.common import ( + ensure_funaudiochat_importable, + register_funaudiochat_processor, +) + +try: + from vllm.model_executor.models.funaudiochat import ( + FunAudioChatForConditionalGeneration as VllmNativeFunAudioChatForConditionalGeneration, + ) +except ImportError: # pragma: no cover - environment-specific dependency + VllmNativeFunAudioChatForConditionalGeneration = None + +_NativeFunAudioChatBase = ( + VllmNativeFunAudioChatForConditionalGeneration + if VllmNativeFunAudioChatForConditionalGeneration is not None + else nn.Module +) + +logger = init_logger(__name__) + +DEFAULT_SP_GEN_KWARGS = { + "text_greedy": True, + "only_crq_sampling": True, + "disable_speech": False, + "force_text_abos": True, +} + +_OFFICIAL_CRQ_SAMPLING_DEFAULTS = { + "repetition_penalty": 1.2, + "temperature": 0.8, + "top_p": 0.9, + "top_k": 0, +} + +_AUDIO_TOKEN_IDS_KEY = "funaudiochat_audio_token_ids" +_CRQ_AUDIO_EMBEDS_KEY = "funaudiochat_crq_audio_embeds" +_CRQ_PAST_KEY_VALUES_KEY = "funaudiochat_crq_past_key_values" +_CURRENT_INPUT_TOKEN_ID_KEY = "funaudiochat_current_input_token_id" +_FORCE_AUDIO_BOS_KEY = 
"funaudiochat_force_audio_bos_pending" +_FINISH_SPEECH_KEY = "funaudiochat_finish_speech" +_GENERATE_SPEECH_KEY = "funaudiochat_generate_speech" +_RAW_TEXT_TOKEN_ID_KEY = "funaudiochat_raw_text_token_id" +_SPEECH_IDS_KEY = "funaudiochat_speech_ids" +_TEXT_INPUT_IDS_KEY = "funaudiochat_text_input_ids" +_TEXT_SEQ_LEN_KEY = "funaudiochat_text_seq_len" + + +@register_funaudiochat_processor +class FunAudioChatForConditionalGeneration(_NativeFunAudioChatBase, SupportsMultiModal): + supports_multimodal_raw_input_only = True + supports_multimodal = True + requires_raw_input_tokens = True + input_modalities = "audio" + pooler_output_buffer_keys = ("audio_token_ids",) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + if VllmNativeFunAudioChatForConditionalGeneration is None: + raise ImportError( + "Installed vLLM does not expose a native FunAudioChat model. " + "Upgrade vLLM to a build that includes " + "`vllm.model_executor.models.funaudiochat`." + ) + + super().__init__(vllm_config=vllm_config, prefix=prefix) + ensure_funaudiochat_importable() + from funaudiochat.modeling_funaudiochat import FunAudioChatDecoder # type: ignore + + self.audio_invert_tower = FunAudioChatDecoder(self.config.audio_config) + self._patch_audio_invert_tower_sampling_step() + self.sp_gen_kwargs = DEFAULT_SP_GEN_KWARGS.copy() + self.has_preprocess = True + self.has_postprocess = True + self.have_multimodal_outputs = False + self._batch_preprocess_in_progress = False + self._batch_req_infos: list[dict[str, Any]] = [] + self._batch_sidecar_results: list[dict[str, Any]] = [] + self._postprocess_cursor = 0 + self._logged_stage0_backend = False + + @staticmethod + def _move_nested_to_cpu(value: Any) -> Any: + if isinstance(value, torch.Tensor): + return value.detach().to("cpu").contiguous() + if isinstance(value, tuple): + return tuple(FunAudioChatForConditionalGeneration._move_nested_to_cpu(v) for v in value) + if isinstance(value, list): + return 
[FunAudioChatForConditionalGeneration._move_nested_to_cpu(v) for v in value] + return value + + @staticmethod + def _move_nested_to_device(value: Any, device: torch.device) -> Any: + if isinstance(value, torch.Tensor): + return value.to(device=device) + if isinstance(value, tuple): + return tuple(FunAudioChatForConditionalGeneration._move_nested_to_device(v, device) for v in value) + if isinstance(value, list): + return [FunAudioChatForConditionalGeneration._move_nested_to_device(v, device) for v in value] + return value + + @staticmethod + def _as_2d_long_tensor(value: Any, device: torch.device) -> torch.Tensor: + if value is None: + return torch.empty((1, 0), dtype=torch.long, device=device) + if isinstance(value, torch.Tensor): + tensor = value.to(device=device, dtype=torch.long) + else: + tensor = torch.as_tensor(value, dtype=torch.long, device=device) + if tensor.ndim == 0: + tensor = tensor.reshape(1, 1) + elif tensor.ndim == 1: + tensor = tensor.unsqueeze(0) + return tensor + + def _patch_audio_invert_tower_sampling_step(self) -> None: + if getattr(self.audio_invert_tower, "_vllm_omni_crq_generator_patched", False): + return + + def _sampling_step_with_generator( + decoder_self: nn.Module, + logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + next_token_logits = logits[:, -1, :].to(copy=True, dtype=torch.float32, device=logits.device) + next_token_scores = decoder_self.crq_logits_processor( + torch.cat([decoder_self.crq_speech_ids, *decoder_self.crq_generate_tokens], dim=-1), + next_token_logits, + ) + + if decoder_self.crq_do_sample: + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + return next_tokens, logits + + self.audio_invert_tower.sampling_step = MethodType(_sampling_step_with_generator, self.audio_invert_tower) + self.audio_invert_tower._vllm_omni_crq_generator_patched = True + + def 
@staticmethod
def _resolve_next_speech_state(
    *,
    sampled_token_id: int,
    generate_speech: bool,
    finish_speech: bool,
    force_audio_bos_pending: bool,
    audio_bos_id: int,
    audio_eos_id: int,
) -> tuple[int, bool, bool]:
    """Decide the token to emit and the speech flags for the next step.

    Returns:
        ``(final_token_id, next_speech_active, next_force_audio_bos_pending)``.
        The third element is always ``False``: a pending forced BOS is
        consumed by this call.
    """
    # A finished sidecar always closes the speech span with the EOS token.
    if finish_speech:
        return audio_eos_id, False, False

    emitted = audio_bos_id if force_audio_bos_pending else sampled_token_id
    if emitted == audio_eos_id:
        # EOS ends speech even when it was active before this step.
        still_speaking = False
    else:
        # Speech stays on once started, and starts when BOS is emitted.
        still_speaking = generate_speech or emitted == audio_bos_id
    return emitted, still_speaking, False
def _run_audio_sidecar_step(
    self,
    hidden_state: torch.Tensor,
    current_input_token_id: int,
    speech_ids: torch.Tensor,
    cached_audio_embeds: Any,
    cached_past_key_values: Any,
    logits_processor: LogitsProcessorList,
    do_sample: bool,
    current_text_seq_len: int,
) -> dict[str, Any]:
    """Run one ``crq_generate_forward`` call of the audio tower for a request.

    The tower input is the request's hidden state plus the (detached)
    embedding of the current input token.  Cached tower state is moved onto
    the hidden state's device before the call and back to CPU afterwards.

    Returns:
        A dict holding the new audio tokens, the CPU copies of the tower's
        audio-embed / KV caches, the finish flag (set when any generated
        token equals the audio EOS id, in which case the whole group is
        overwritten with EOS) and the extended speech ids on CPU.
    """
    device = hidden_state.device
    tower = self.audio_invert_tower
    # Mask length is clamped to at least one; the position is its last index.
    seq_len = max(current_text_seq_len, 1)

    token_tensor = torch.tensor([current_input_token_id], device=device, dtype=torch.long)
    text_embed = self.get_language_model().embed_input_ids(token_tensor).reshape(1, 1, -1)
    speech_inputs_embeds = hidden_state.reshape(1, 1, -1) + text_embed.detach()

    # Restore the per-request tower state before stepping.
    tower.crq_audio_embeds = self._move_nested_to_device(cached_audio_embeds, device)
    tower.crq_past_key_values = self._move_nested_to_device(cached_past_key_values, device)
    tower.crq_do_sample = do_sample
    tower.crq_logits_processor = logits_processor
    tower.crq_speech_ids = speech_ids
    tower.crq_generate_forward(
        inputs_embeds=speech_inputs_embeds,
        attention_mask=torch.ones((1, seq_len), dtype=torch.long, device=device),
        position_ids=torch.tensor([[seq_len - 1]], dtype=torch.long, device=device),
        return_dict=True,
    )

    eos_token_id = int(self.config.audio_config.eos_token_id)
    next_audio_tokens = tower.crq_generate_tokens.reshape(1, -1).to(dtype=torch.long)
    finish_speech = bool((next_audio_tokens == eos_token_id).any().item())
    if finish_speech:
        next_audio_tokens = torch.full_like(next_audio_tokens, eos_token_id)

    return {
        _AUDIO_TOKEN_IDS_KEY: next_audio_tokens.detach(),
        _CRQ_AUDIO_EMBEDS_KEY: self._move_nested_to_cpu(tower.crq_audio_embeds),
        _CRQ_PAST_KEY_VALUES_KEY: self._move_nested_to_cpu(tower.crq_past_key_values),
        _FINISH_SPEECH_KEY: finish_speech,
        _SPEECH_IDS_KEY: torch.cat([speech_ids, next_audio_tokens], dim=-1).detach().to("cpu").contiguous(),
    }
def _run_audio_sidecar_prefill_warmup(
    self,
    hidden_states: torch.Tensor,
    input_ids: torch.Tensor,
    speech_ids: torch.Tensor,
    cached_audio_embeds: Any,
    cached_past_key_values: Any,
    logits_processor: LogitsProcessorList,
    do_sample: bool,
) -> dict[str, Any]:
    """Feed a whole prompt span through the audio tower to build its caches.

    Every position's hidden state is summed with the (detached) embedding
    of the matching input id and passed to ``crq_generate_forward`` in one
    call, without an explicit attention mask or position ids.

    Returns:
        A dict with only the tower's audio-embed and KV caches, moved to CPU.
    """
    device = hidden_states.device
    hidden_dim = hidden_states.shape[-1]
    tower = self.audio_invert_tower

    flat_ids = input_ids.to(device=device, dtype=torch.long).reshape(1, -1).reshape(-1)
    text_embeds = self.get_language_model().embed_input_ids(flat_ids).reshape(1, -1, hidden_dim)
    speech_inputs_embeds = hidden_states.reshape(1, -1, hidden_dim) + text_embeds.detach()

    # Restore the per-request tower state before the forward call.
    tower.crq_audio_embeds = self._move_nested_to_device(cached_audio_embeds, device)
    tower.crq_past_key_values = self._move_nested_to_device(cached_past_key_values, device)
    tower.crq_do_sample = do_sample
    tower.crq_logits_processor = logits_processor
    tower.crq_speech_ids = speech_ids
    tower.crq_generate_forward(inputs_embeds=speech_inputs_embeds, return_dict=True)

    return {
        _CRQ_AUDIO_EMBEDS_KEY: self._move_nested_to_cpu(tower.crq_audio_embeds),
        _CRQ_PAST_KEY_VALUES_KEY: self._move_nested_to_cpu(tower.crq_past_key_values),
    }
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: Any = None,
) -> torch.Tensor | None:
    """Compute text logits and run the audio sidecar for every queued request.

    For each entry recorded in ``self._batch_req_infos`` this builds one
    result dict in ``self._batch_sidecar_results`` (consumed later, in
    order, by ``postprocess``):

    * speech active: run one sidecar step and store its outputs;
    * speech inactive, multi-token span: only mark the request so that
      ``postprocess`` runs the prefill warm-up;
    * speech inactive, single token: run the decode warm-up immediately.

    Nothing beyond the default placeholders is produced when
    ``sp_gen_kwargs["disable_speech"]`` is set.

    Args:
        hidden_states: Hidden states handed to the parent ``compute_logits``;
            row ``i`` is also used as the sidecar input of request ``i``.
        sampling_metadata: Optional per-batch sampling settings forwarded to
            ``_build_crq_sampling_config``.

    Returns:
        The logits of the parent implementation, or ``None`` when the parent
        returns ``None``.
    """
    logits = super().compute_logits(hidden_states)
    if logits is None:
        self._batch_preprocess_in_progress = False
        return None

    device = hidden_states.device
    cpu = torch.device("cpu")
    speech_enabled = not self.sp_gen_kwargs["disable_speech"]
    raw_argmax_token_ids = torch.argmax(logits, dim=-1).detach().to("cpu")

    self._batch_sidecar_results = []
    for idx, req_info in enumerate(self._batch_req_infos):
        force_audio_bos_pending = bool(req_info.get(_FORCE_AUDIO_BOS_KEY, False))
        speech_active = bool(req_info.get(_GENERATE_SPEECH_KEY, False))
        raw_text_token_id = int(raw_argmax_token_ids[idx].item())

        # Placeholders are created on the CPU directly: allocating them on
        # the accelerator and copying back costs a device-to-host transfer
        # per request for a constant tensor.
        sidecar_result = {
            _AUDIO_TOKEN_IDS_KEY: self._empty_audio_token_ids(cpu),
            _CRQ_AUDIO_EMBEDS_KEY: req_info.get(_CRQ_AUDIO_EMBEDS_KEY),
            _CRQ_PAST_KEY_VALUES_KEY: req_info.get(_CRQ_PAST_KEY_VALUES_KEY),
            _FORCE_AUDIO_BOS_KEY: force_audio_bos_pending,
            _FINISH_SPEECH_KEY: False,
            _GENERATE_SPEECH_KEY: speech_active,
            _RAW_TEXT_TOKEN_ID_KEY: raw_text_token_id,
            _SPEECH_IDS_KEY: req_info.get(_SPEECH_IDS_KEY),
            "audio_token_ids": self._empty_audio_token_ids(cpu),
        }

        if speech_enabled:
            # The sampling config is only needed when the sidecar may run.
            crq_logits_processor, do_sample = self._build_crq_sampling_config(
                sampling_metadata=sampling_metadata,
                req_index=idx,
            )
            if speech_active:
                sidecar_step = self._run_audio_sidecar_step(
                    hidden_state=hidden_states[idx],
                    current_input_token_id=int(req_info[_CURRENT_INPUT_TOKEN_ID_KEY]),
                    speech_ids=self._as_2d_long_tensor(req_info.get(_SPEECH_IDS_KEY), device),
                    cached_audio_embeds=req_info.get(_CRQ_AUDIO_EMBEDS_KEY),
                    cached_past_key_values=req_info.get(_CRQ_PAST_KEY_VALUES_KEY),
                    logits_processor=crq_logits_processor,
                    do_sample=do_sample,
                    current_text_seq_len=int(req_info.get(_TEXT_SEQ_LEN_KEY, 1)),
                )
                sidecar_result.update(sidecar_step)
                sidecar_result["audio_token_ids"] = sidecar_step[_AUDIO_TOKEN_IDS_KEY]
            # Only the element count matters here, so the ids stay on the CPU.
            elif self._as_2d_long_tensor(req_info.get(_TEXT_INPUT_IDS_KEY), cpu).numel() > 1:
                # Multi-token span: defer the warm-up to ``postprocess``,
                # which receives the hidden states for this request.
                sidecar_result["_run_prefill_crq_warmup"] = True
                sidecar_result["_prefill_input_ids"] = req_info.get(_TEXT_INPUT_IDS_KEY)
                sidecar_result["_prefill_crq_logits_processor"] = crq_logits_processor
                sidecar_result["_prefill_crq_do_sample"] = do_sample
            else:
                warmup_state = self._run_audio_sidecar_decode_warmup(
                    hidden_state=hidden_states[idx],
                    current_input_token_id=int(req_info[_CURRENT_INPUT_TOKEN_ID_KEY]),
                    speech_ids=self._as_2d_long_tensor(req_info.get(_SPEECH_IDS_KEY), device),
                    cached_audio_embeds=req_info.get(_CRQ_AUDIO_EMBEDS_KEY),
                    cached_past_key_values=req_info.get(_CRQ_PAST_KEY_VALUES_KEY),
                    logits_processor=crq_logits_processor,
                    do_sample=do_sample,
                )
                sidecar_result.update(warmup_state)

        self._batch_sidecar_results.append(sidecar_result)

    # Hand the batch over to ``postprocess`` and allow the next
    # ``preprocess`` call to start a fresh batch.
    self._postprocess_cursor = 0
    self._batch_preprocess_in_progress = False
    return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load ``weights`` into this module through ``AutoWeightsLoader``.

    Returns:
        Whatever ``AutoWeightsLoader.load_weights`` returns for ``weights``.
    """
    return AutoWeightsLoader(self).load_weights(weights)
b/vllm_omni/model_executor/models/funaudiochat/funaudiochat_code2wav.py @@ -0,0 +1,453 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import os +from collections.abc import Iterable +from pathlib import Path +from typing import Any + +import torch +import torch.nn as nn +from huggingface_hub import snapshot_download +from torch.nn import functional as F +from vllm.config import VllmConfig +from vllm.logger import init_logger + +from vllm_omni.model_executor.models.cosyvoice3.cosyvoice3_code2wav import CosyVoice3Code2Wav +from vllm_omni.model_executor.models.cosyvoice3.utils import make_pad_mask +from vllm_omni.model_executor.models.funaudiochat.common import resolve_funaudiochat_root +from vllm_omni.model_executor.models.output_templates import OmniOutput + +logger = init_logger(__name__) +_OFFICIAL_TOKEN_HOP_LEN = 25 * 30 +_OFFICIAL_MIN_SEGMENT_TOKENS = 50 + + +class FunAudioChatCosyVoice3Code2Wav(nn.Module): + input_modalities = "audio" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + del prefix + super().__init__() + self.vllm_config = vllm_config + self.model_path = self._resolve_model_path(vllm_config.model_config.model) + self.have_multimodal_outputs = True + self.enable_update_additional_information = False + self.requires_raw_input_tokens = True + self.hf_config_path = getattr(vllm_config.model_config, "hf_config_path", None) + + from transformers import AutoConfig + + from vllm_omni.model_executor.models.cosyvoice3.config import CosyVoice3Config + + try: + AutoConfig.register(CosyVoice3Config.model_type, CosyVoice3Config) + except ValueError: + pass + + config_source = self.hf_config_path or self.model_path + self.config = AutoConfig.from_pretrained(config_source, trust_remote_code=True) + self.code2wav = CosyVoice3Code2Wav(self.config) + # Keep FunAudioChat's stage-1 flow stack in float32 to match the + # official runtime 
def _load_default_speaker_embedding(self) -> torch.Tensor:
    """Load the default speaker embedding used for every synthesis request.

    The speaker table is read from the file named by the
    ``FUN_AUDIO_CHAT_SPK_INFO`` environment variable when it is set, and
    otherwise from ``utils/new_spk2info.pt`` under the Fun-Audio-Chat root.

    Returns:
        The ``"中文女"`` speaker's embedding as a float32 ``(1, D)`` tensor.

    Raises:
        FileNotFoundError: If the speaker table is not an existing file.
        KeyError: If the table has no ``"中文女"`` entry with an
            ``"embedding"`` field.
    """
    speaker_name = "中文女"

    env_path = os.environ.get("FUN_AUDIO_CHAT_SPK_INFO")
    if env_path:
        spk_path = Path(env_path).expanduser()
    else:
        spk_path = resolve_funaudiochat_root() / "utils" / "new_spk2info.pt"
    # ``is_file`` (not ``exists``) so a directory is rejected here with the
    # actionable message instead of failing later inside ``torch.load``.
    if not spk_path.is_file():
        raise FileNotFoundError(
            f"Default speaker embedding not found: {spk_path}. "
            "Set FUN_AUDIO_CHAT_SPK_INFO or install Fun-Audio-Chat from source."
        )

    # NOTE(security): ``torch.load`` unpickles the file, and the path can come
    # from the environment.  Only point FUN_AUDIO_CHAT_SPK_INFO at trusted
    # files.  TODO(review): pass ``weights_only=True`` once it is confirmed
    # that the shipped speaker table loads under it.
    spk_info = torch.load(spk_path, map_location="cpu")
    try:
        embedding = spk_info[speaker_name]["embedding"]
    except KeyError as exc:
        available = ", ".join(str(name) for name in spk_info)
        raise KeyError(
            f"Speaker {speaker_name!r} with an 'embedding' entry not found in {spk_path}; "
            f"available speakers: [{available}]"
        ) from exc
    return embedding.reshape(1, -1).float()
input_ids: torch.Tensor, + sampling_metadata: Any, + seq_token_counts: list[int] | None = None, + ) -> tuple[list[torch.Tensor], bool]: + prompt_token_id_batches = self._get_prompt_token_id_batches(sampling_metadata) + if prompt_token_id_batches is not None: + raw_id_batches = prompt_token_id_batches + elif input_ids is not None: + raw_id_batches = self._split_request_ids(input_ids.reshape(-1), seq_token_counts) + else: + raw_id_batches = [torch.empty((0,), dtype=torch.long)] + + token_batches = [ + raw_ids.reshape(1, -1).to(dtype=torch.long, device=self.vllm_config.device_config.device).clamp_( + min=0, + max=self._max_codec_token_id, + ) + for raw_ids in raw_id_batches + ] + + is_dummy_profile = bool( + sampling_metadata is None + and prompt_token_id_batches is None + and len(token_batches) == 1 + and (token_batches[0].numel() == 0 or torch.count_nonzero(token_batches[0]).item() == 0) + ) + if is_dummy_profile and token_batches[0].shape[1] > self._dummy_profile_token_len: + if not self._logged_dummy_profile_cap: + logger.info( + "FunAudioChat code2wav dummy/profile run detected. 
@staticmethod
def _fade_in_out(fade_in_tensor: torch.Tensor, fade_out_tensor: torch.Tensor, window: Any) -> torch.Tensor:
    """Cross-fade the tail of ``fade_out_tensor`` into the head of ``fade_in_tensor``.

    The overlap is half the window length, capped by the last-dimension
    size of both tensors.  The head of the result is
    ``fade_in * window[:overlap] + fade_out[..., -overlap:] * window[-overlap:]``.

    Returns:
        A blended copy of ``fade_in_tensor``.  When either tensor is empty
        or the overlap is zero, ``fade_in_tensor`` itself is returned
        unchanged.
    """
    if 0 in (fade_in_tensor.numel(), fade_out_tensor.numel()):
        return fade_in_tensor

    half_window = int(len(window) // 2)
    overlap = min(half_window, fade_in_tensor.shape[-1], fade_out_tensor.shape[-1])
    if overlap <= 0:
        return fade_in_tensor

    weights = torch.as_tensor(window, device=fade_in_tensor.device, dtype=fade_in_tensor.dtype)
    rising, falling = weights[:overlap], weights[-overlap:]

    # Work on a copy so the caller's tensor is left untouched.
    blended = fade_in_tensor.clone()
    head = blended[..., :overlap] * rising + fade_out_tensor[..., -overlap:] * falling
    blended[..., :overlap] = head
    return blended
def _run_hift_like_official(
    self,
    speech_feat: torch.Tensor,
    *,
    finalize: bool,
    cache_source: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the HiFT vocoder on ``speech_feat`` and return speech plus source.

    The f0 predictor is moved to the CPU and evaluated on a CPU copy of the
    float32 features; its output is cast back to the features' device/dtype
    before the source signal is built.

    Args:
        speech_feat: Features to vocode; converted to float32.
        finalize: ``True`` for the last chunk.  Otherwise the trailing
            ``causal_padding`` frames of the features are withheld from
            ``decode``.
        cache_source: Source signal cached from the previous chunk; when its
            third dimension is non-empty it overwrites the start of the new
            source.

    Returns:
        ``(speech, source)`` as produced by ``hift.decode`` and the source
        module.
    """
    vocoder = self.code2wav.hift
    feats = speech_feat.to(dtype=torch.float32)

    predictor = vocoder.f0_predictor
    predictor.to("cpu")
    f0 = predictor(feats.cpu(), finalize=finalize).to(feats)

    upsampled = vocoder.f0_upsamp(f0[:, None]).transpose(1, 2)
    source, _, _ = vocoder.m_source(upsampled)
    source = source.transpose(1, 2)

    cached_len = cache_source.shape[2]
    if cached_len != 0:
        # Keep the overlap region continuous with the previous chunk.
        source[:, :, :cached_len] = cache_source

    if not finalize:
        padding = vocoder.f0_predictor.condnet[0].causal_padding
        speech = vocoder.decode(x=feats[:, :, :-padding], s=source, finalize=False)
        return speech, source
    return vocoder.decode(x=feats, s=source, finalize=True), source
def forward(
    self,
    input_ids: torch.Tensor | None = None,
    positions: torch.Tensor | None = None,
    intermediate_tensors: Any = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: Any,
) -> OmniOutput:
    """Decode codec tokens into waveforms, one entry per request.

    ``positions``, ``intermediate_tensors`` and ``inputs_embeds`` are
    ignored.  Tokens come from ``_build_decode_tokens`` using ``input_ids``
    and the ``sampling_metadata`` / ``seq_token_counts`` keyword arguments.

    Returns:
        An ``OmniOutput`` without text hidden states whose
        ``multimodal_outputs`` holds ``"audio"`` (a list of 1-D CPU
        waveforms, empty for requests without tokens) and ``"sr"`` (a list
        repeating the scalar tensor ``24000``).  When no request has tokens,
        or on a dummy/profile run, only empty waveforms are returned and
        nothing is decoded.
    """
    del positions, intermediate_tensors, inputs_embeds

    token_batches, is_dummy_profile = self._build_decode_tokens(
        input_ids,
        kwargs.get("sampling_metadata"),
        kwargs.get("seq_token_counts"),
    )
    silence = torch.zeros((0,), dtype=torch.float32)
    sample_rate = torch.tensor(24000, dtype=torch.int32)

    # Nothing to synthesise for any request: still emit one slot each.
    if not any(batch.numel() for batch in token_batches):
        slots = max(len(token_batches), 1)
        return OmniOutput(
            text_hidden_states=None,
            multimodal_outputs={"audio": [silence] * slots, "sr": [sample_rate] * slots},
        )

    if is_dummy_profile:
        return OmniOutput(
            text_hidden_states=None,
            multimodal_outputs={"audio": [silence.to(device=token_batches[0].device)], "sr": [sample_rate]},
        )

    audios: list[torch.Tensor] = []
    for token in token_batches:
        if token.numel() == 0:
            audios.append(silence)
            continue

        device = token.device
        # No prompt tokens or prompt features are used; only the default
        # speaker embedding conditions the decode.
        prompt_token = torch.zeros((1, 0), dtype=torch.long, device=device)
        prompt_feat = torch.zeros((1, 0, 80), dtype=torch.float32, device=device)
        embedding = self._speaker_embedding.to(device=device, dtype=torch.float32)

        pieces: list[torch.Tensor] = []
        for token_segment in self._split_tokens_like_official(token):
            if token_segment.numel() == 0:
                continue
            segment_audio = self._decode_segment_like_official(
                token_segment,
                prompt_token,
                prompt_feat,
                embedding,
            )
            if segment_audio.numel() > 0:
                pieces.append(segment_audio.reshape(-1))

        waveform = torch.cat(pieces, dim=0) if pieces else torch.zeros((0,), device=device)
        audios.append(waveform.reshape(-1).detach().cpu())

    return OmniOutput(
        text_hidden_states=None,
        multimodal_outputs={"audio": audios, "sr": [sample_rate] * len(audios)},
    )
loaded eagerly from the local snapshot in `__init__`. + return {name for name, _ in self.named_parameters()} diff --git a/vllm_omni/model_executor/models/registry.py b/vllm_omni/model_executor/models/registry.py index f708b6ff904..058111f9b83 100644 --- a/vllm_omni/model_executor/models/registry.py +++ b/vllm_omni/model_executor/models/registry.py @@ -57,6 +57,16 @@ "cosyvoice3", "CosyVoice3Model", ), + "FunAudioChatForConditionalGeneration": ( + "funaudiochat", + "funaudiochat", + "FunAudioChatForConditionalGeneration", + ), + "FunAudioChatCosyVoice3Code2Wav": ( + "funaudiochat", + "funaudiochat_code2wav", + "FunAudioChatCosyVoice3Code2Wav", + ), "MammothModa2Qwen2ForCausalLM": ( "mammoth_moda2", "mammoth_moda2", diff --git a/vllm_omni/model_executor/stage_configs/funaudiochat_s2s.yaml b/vllm_omni/model_executor/stage_configs/funaudiochat_s2s.yaml new file mode 100644 index 00000000000..aca184567f0 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/funaudiochat_s2s.yaml @@ -0,0 +1,66 @@ +stage_args: + - stage_id: 0 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model: FunAudioLLM/Fun-Audio-Chat-8B + model_stage: s2s + model_arch: FunAudioChatForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + trust_remote_code: true + enable_prefix_caching: false + gpu_memory_utilization: 0.55 + engine_output_type: latent + max_model_len: 8192 + max_num_batched_tokens: 8192 + enforce_eager: true + distributed_executor_backend: "mp" + dtype: bfloat16 + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.8 + top_p: 0.9 + top_k: 0 + repetition_penalty: 1.2 + max_tokens: 2048 + stop_token_ids: [151645] + detokenize: true + seed: 42 + + - stage_id: 1 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model: FunAudioLLM/Fun-CosyVoice3-0.5B-2512 + model_stage: code2wav + model_arch: 
FunAudioChatCosyVoice3Code2Wav + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + trust_remote_code: true + enable_prefix_caching: false + gpu_memory_utilization: 0.25 + engine_output_type: audio + max_model_len: 32768 + max_num_batched_tokens: 32768 + skip_tokenizer_init: true + enforce_eager: true + distributed_executor_backend: "mp" + dtype: bfloat16 + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.funaudiochat.funaudiochat2code2wav + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 1 + detokenize: true + seed: 42 diff --git a/vllm_omni/model_executor/stage_input_processors/funaudiochat.py b/vllm_omni/model_executor/stage_input_processors/funaudiochat.py new file mode 100644 index 00000000000..5a179b063bd --- /dev/null +++ b/vllm_omni/model_executor/stage_input_processors/funaudiochat.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import Any + +import torch +from vllm.logger import init_logger + +from vllm_omni.inputs.data import OmniTokensPrompt + +_MAX_COSYVOICE_TOKEN_ID = 6561 +logger = init_logger(__name__) + + +def _validate_stage_inputs(stage_list: list[Any], engine_input_source: list[int]) -> Any: + if not engine_input_source: + raise ValueError("engine_input_source cannot be empty") + + source_stage_id = engine_input_source[0] + if source_stage_id >= len(stage_list): + raise IndexError(f"Invalid stage_id: {source_stage_id}") + + stage_outputs = stage_list[source_stage_id].engine_outputs + if stage_outputs is None: + raise RuntimeError(f"Stage {source_stage_id} has no outputs yet") + return stage_outputs + + +def _to_flat_audio_token_ids(audio_token_ids: Any) -> torch.Tensor: + if not isinstance(audio_token_ids, torch.Tensor): + audio_token_ids = torch.as_tensor(audio_token_ids, dtype=torch.long) + audio_token_ids = 
audio_token_ids.to(dtype=torch.long) + if audio_token_ids.ndim == 2: + # Token id 0 is valid for code2wav. Only drop rows that are fully negative + # placeholders, and preserve all-zero codec groups from stage-0. + valid_rows = (audio_token_ids >= 0).any(dim=-1) + audio_token_ids = audio_token_ids[valid_rows] + return audio_token_ids.reshape(-1) + + +def funaudiochat2code2wav( + stage_list: list[Any], + engine_input_source: list[int], + prompt: Any = None, + requires_multimodal_data: bool = False, +) -> list[OmniTokensPrompt]: + """Convert FunAudioChat stage-0 audio codec output into code2wav prompts.""" + del prompt, requires_multimodal_data + + stage_outputs = _validate_stage_inputs(stage_list, engine_input_source) + code2wav_inputs: list[OmniTokensPrompt] = [] + for stage_output in stage_outputs: + output = stage_output.outputs[0] + mm_output = getattr(output, "multimodal_output", None) or {} + audio_token_ids = mm_output.get("audio_token_ids") + if audio_token_ids is None: + audio_token_ids = mm_output.get("speech_ids") + if audio_token_ids is None: + raise ValueError("Stage-0 FunAudioChat output does not contain `speech_ids` or `audio_token_ids`.") + flat_audio_token_ids = _to_flat_audio_token_ids(audio_token_ids) + filtered = flat_audio_token_ids[(flat_audio_token_ids >= 0) & (flat_audio_token_ids < _MAX_COSYVOICE_TOKEN_ID)] + raw_min = int(flat_audio_token_ids.min().item()) if flat_audio_token_ids.numel() > 0 else None + raw_max = int(flat_audio_token_ids.max().item()) if flat_audio_token_ids.numel() > 0 else None + logger.info( + "FunAudioChat stage0->stage1 audio tokens: raw_len=%d filtered_len=%d raw_min=%s raw_max=%s tail=%s", + flat_audio_token_ids.numel(), + filtered.numel(), + raw_min, + raw_max, + flat_audio_token_ids[-8:].tolist() if flat_audio_token_ids.numel() > 0 else [], + ) + code2wav_inputs.append( + OmniTokensPrompt( + prompt_token_ids=filtered.to(dtype=torch.long).reshape(-1).tolist(), + multi_modal_data=None, + mm_processor_kwargs=None, + ) 
+ ) + return code2wav_inputs diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py index 6f035a81ec7..3e16290ff69 100644 --- a/vllm_omni/outputs.py +++ b/vllm_omni/outputs.py @@ -145,6 +145,13 @@ def multimodal_output(self) -> dict[str, Any]: return mm return self._multimodal_output + def _request_outputs_list(self) -> list[RequestOutput]: + if self.request_output is None: + return [] + if isinstance(self.request_output, list): + return list(self.request_output) + return [self.request_output] + @property def custom_output(self) -> dict[str, Any]: """Return custom output data from diffusion pipelines. @@ -176,8 +183,9 @@ def prompt_token_ids(self) -> list[int] | None: This property is required for compatibility with vLLM's streaming chat completion generator which checks res.prompt_token_ids. """ - if self.request_output is not None: - return getattr(self.request_output, "prompt_token_ids", None) + request_outputs = self._request_outputs_list() + if request_outputs: + return getattr(request_outputs[0], "prompt_token_ids", None) return None @property @@ -187,36 +195,41 @@ def outputs(self) -> list[Any]: This property is required for compatibility with vLLM's streaming and non-streaming chat completion generators. 
""" - if self.request_output is not None: - return getattr(self.request_output, "outputs", []) - return [] + outputs: list[Any] = [] + for req_out in self._request_outputs_list(): + outputs.extend(getattr(req_out, "outputs", []) or []) + return outputs @property def encoder_prompt_token_ids(self) -> list[int] | None: """Return encoder prompt token IDs from the underlying request output.""" - if self.request_output is not None: - return getattr(self.request_output, "encoder_prompt_token_ids", None) + request_outputs = self._request_outputs_list() + if request_outputs: + return getattr(request_outputs[0], "encoder_prompt_token_ids", None) return None @property def prompt_logprobs(self) -> Any: """Return prompt logprobs from the underlying request output.""" - if self.request_output is not None: - return getattr(self.request_output, "prompt_logprobs", None) + request_outputs = self._request_outputs_list() + if request_outputs: + return getattr(request_outputs[0], "prompt_logprobs", None) return None @property def num_cached_tokens(self) -> int | None: """Return number of cached tokens from the underlying request output.""" - if self.request_output is not None: - return getattr(self.request_output, "num_cached_tokens", None) + request_outputs = self._request_outputs_list() + if request_outputs: + return getattr(request_outputs[0], "num_cached_tokens", None) return None @property def kv_transfer_params(self) -> Any: """Return KV transfer params from the underlying request output.""" - if self.request_output is not None: - return getattr(self.request_output, "kv_transfer_params", None) + request_outputs = self._request_outputs_list() + if request_outputs: + return getattr(request_outputs[0], "kv_transfer_params", None) return None @property diff --git a/vllm_omni/transformers_utils/configs/__init__.py b/vllm_omni/transformers_utils/configs/__init__.py index 59b23f91490..eaf000c23ff 100644 --- a/vllm_omni/transformers_utils/configs/__init__.py +++ 
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""FunAudioChat config registration.

The public Fun-Audio-Chat checkpoints use ``model_type="funaudiochat"``.
Depending on the runtime environment, that type may not be known to the
installed Transformers build. Register it eagerly so vLLM-Omni can load the
checkpoint config before the model wrapper is instantiated.
"""

from __future__ import annotations

from transformers import AutoConfig, PretrainedConfig

try:
    from funaudiochat.configuration_funaudiochat import (  # type: ignore
        FunAudioChatAudioEncoderConfig,
        FunAudioChatConfig,
    )
except ImportError:
    # The optional ``funaudiochat`` package is not installed: fall back to
    # local config classes registered under the same ``model_type`` strings.

    class FunAudioChatAudioEncoderConfig(PretrainedConfig):
        """Fallback config for the FunAudioChat audio encoder.

        Every constructor argument is stored verbatim as an attribute of the
        same name.  ``num_hidden_layers`` mirrors ``encoder_layers`` and the
        attention implementation defaults to ``"sdpa"`` when not given.
        """

        model_type = "funaudiochat_audio_encoder"

        def __init__(
            self,
            _attn_implementation: str | None = None,
            num_mel_bins: int = 128,
            encoder_layers: int = 32,
            encoder_attention_heads: int = 20,
            encoder_ffn_dim: int = 5120,
            d_model: int = 1280,
            dropout: float = 0.0,
            attention_dropout: float = 0.0,
            activation_function: str = "gelu",
            activation_dropout: float = 0.0,
            scale_embedding: bool = False,
            initializer_range: float = 0.02,
            max_source_positions: int = 1500,
            n_window: int = 100,
            output_dim: int = 3584,
            bos_token_id: int | None = None,
            codebook_size: int | None = None,
            continuous_features_mode: str = "replace",
            crq_transformer_config: dict | None = None,
            eos_token_id: int | None = None,
            group_size: int = 5,
            enable_audio_invert_tower: bool = True,
            pad_token_id: int | None = None,
            **kwargs,
        ) -> None:
            # ``_attn_implementation`` is a named parameter, so it can never
            # show up inside ``kwargs``; no extra pop is needed.
            super().__init__(**kwargs)
            self._attn_implementation = _attn_implementation or "sdpa"
            self.num_mel_bins = num_mel_bins
            self.d_model = d_model
            self.encoder_layers = encoder_layers
            self.encoder_attention_heads = encoder_attention_heads
            self.encoder_ffn_dim = encoder_ffn_dim
            self.dropout = dropout
            self.attention_dropout = attention_dropout
            self.activation_function = activation_function
            self.activation_dropout = activation_dropout
            self.num_hidden_layers = encoder_layers
            self.initializer_range = initializer_range
            self.scale_embedding = scale_embedding
            self.max_source_positions = max_source_positions
            self.n_window = n_window
            self.output_dim = output_dim
            self.bos_token_id = bos_token_id
            self.codebook_size = codebook_size
            self.continuous_features_mode = continuous_features_mode
            self.crq_transformer_config = crq_transformer_config
            self.eos_token_id = eos_token_id
            self.group_size = group_size
            self.enable_audio_invert_tower = enable_audio_invert_tower
            self.pad_token_id = pad_token_id

    class FunAudioChatConfig(PretrainedConfig):
        """Fallback top-level FunAudioChat config (audio encoder + text model).

        ``audio_config`` / ``text_config`` may be config objects, plain dicts
        or ``None``.  Dicts are copied before defaults are filled in, so the
        caller's mappings are left untouched.  ``hidden_size`` falls back to
        the text config's ``hidden_size`` when not given.
        """

        model_type = "funaudiochat"
        attribute_map = {"audio_token_id": "audio_token_index"}

        def __init__(
            self,
            audio_config: PretrainedConfig | dict | None = None,
            text_config: PretrainedConfig | dict | None = None,
            audio_token_index: int = 151646,
            ignore_index: int = -100,
            hidden_size: int | None = None,
            **kwargs,
        ) -> None:
            import transformers

            self.audio_token_index = audio_token_index
            self.ignore_index = ignore_index

            if isinstance(audio_config, dict):
                # Copy first: setdefault must not leak into the caller's dict.
                audio_kwargs = dict(audio_config)
                audio_kwargs.setdefault("model_type", FunAudioChatAudioEncoderConfig.model_type)
                audio_config = FunAudioChatAudioEncoderConfig(**audio_kwargs)
            elif audio_config is None:
                audio_config = FunAudioChatAudioEncoderConfig()
            self.audio_config = audio_config

            if isinstance(text_config, dict):
                text_kwargs = dict(text_config)
                text_kwargs.setdefault("model_type", "qwen2")
                text_cls = transformers.CONFIG_MAPPING[text_kwargs["model_type"]]
                text_config = text_cls(**text_kwargs)
            elif text_config is None:
                text_config = transformers.CONFIG_MAPPING["qwen2"]()
            self.text_config = text_config

            self.hidden_size = int(self.text_config.hidden_size) if hidden_size is None else int(hidden_size)
            super().__init__(**kwargs)


# ``exist_ok=True`` keeps registration from failing when the type is already known.
AutoConfig.register(FunAudioChatAudioEncoderConfig.model_type, FunAudioChatAudioEncoderConfig, exist_ok=True)
AutoConfig.register(FunAudioChatConfig.model_type, FunAudioChatConfig, exist_ok=True)

__all__ = ["FunAudioChatAudioEncoderConfig", "FunAudioChatConfig"]
b/vllm_omni/worker/gpu_ar_model_runner.py @@ -21,7 +21,10 @@ RoutedExpertsCapturer, ) from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput -from vllm.v1.outputs import AsyncModelRunnerOutput, make_empty_encoder_model_runner_output +from vllm.v1.outputs import ( + AsyncModelRunnerOutput, + make_empty_encoder_model_runner_output, +) from vllm.v1.spec_decode.draft_model import DraftModelProposer from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer @@ -276,9 +279,9 @@ def execute_model( # Run the model. # Use persistent buffers for CUDA graphs. - # When spec decode is enabled, delay clearing connector metadata - # until after draft model runs in sample_tokens. - clear_kv_metadata = self.speculative_config is None + # When spec decode is enabled, delay connector finalization until + # after draft model runs in sample_tokens. + defer_finalize = self.speculative_config is not None with ( set_forward_context( attn_metadata, @@ -291,9 +294,7 @@ def execute_model( slot_mapping=slot_mappings, # OMNI: required for KV cache operations ), record_function_or_nullcontext("gpu_model_runner: forward"), - self.maybe_get_kv_connector_output( - scheduler_output, clear_metadata=clear_kv_metadata - ) as kv_connector_output, + self.maybe_get_kv_connector_output(scheduler_output, defer_finalize=defer_finalize) as kv_connector_output, ): model_output = self._model_forward( input_ids=input_ids, @@ -398,6 +399,24 @@ def execute_model( return None + def _postprocess_sampled_token_ids( + self, + sampled_token_ids: torch.Tensor, + *, + model_intermediate_buffer: dict[str, dict[str, Any]] | None = None, + ) -> torch.Tensor: + postprocess = getattr(self.model, "postprocess_sampled_tokens", None) + if postprocess is None or sampled_token_ids.numel() == 0: + return sampled_token_ids + buffer_map = self.model_intermediate_buffer if model_intermediate_buffer is None else model_intermediate_buffer + 
corrected_sampled_token_ids = postprocess( + sampled_token_ids=sampled_token_ids, + req_ids=self.input_batch.req_ids.copy(), + req_id_to_index=self.input_batch.req_id_to_index.copy(), + model_intermediate_buffer=buffer_map, + ) + return sampled_token_ids if corrected_sampled_token_ids is None else corrected_sampled_token_ids + @torch.inference_mode() def sample_tokens( self, @@ -453,6 +472,31 @@ def sample_tokens( with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) + num_scheduled_tokens_np = getattr(self, "_omni_num_scheduled_tokens_np", None) + if num_scheduled_tokens_np is None: + req_ids = self.input_batch.req_ids + num_scheduled_tokens_np = np.array( + [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids], + dtype=np.int32, + ) + + pending_intermediate_updates = self._collect_additional_information_updates( + hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output + ) + overlay_intermediate_buffer = None + postprocess_hook = getattr(self.model, "postprocess_sampled_tokens", None) + if pending_intermediate_updates or postprocess_hook is not None: + overlay_intermediate_buffer = self._build_overlay_intermediate_buffer( + self.input_batch.req_ids.copy(), + pending_intermediate_updates, + ) + + with record_function_or_nullcontext("gpu_model_runner: postprocess_sampled_tokens"): + sampler_output.sampled_token_ids = self._postprocess_sampled_token_ids( + sampler_output.sampled_token_ids, + model_intermediate_buffer=overlay_intermediate_buffer, + ) + self._draft_token_ids = None self._draft_token_req_ids = None self.input_batch.prev_sampled_token_ids = None @@ -528,16 +572,23 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_metadata, ) + if overlay_intermediate_buffer is not None: + self._commit_intermediate_buffer_overlay( + overlay_intermediate_buffer, + invalid_req_indices=invalid_req_indices, + req_ids=self.input_batch.req_ids.copy(), + ) + if 
propose_drafts_after_bookkeeping: # ngram and other speculative decoding methods use the sampled # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) - # Clear KV connector metadata after draft model runs (if spec decode). + # Finalize KV connector after draft model runs (if spec decode). # This was deferred from target model forward to allow draft model # to also save its KV cache. if self.speculative_config is not None: - self.clear_kv_connector_metadata() + self.finalize_kv_connector() with record_function_or_nullcontext("gpu_model_runner: eplb"): self.eplb_step() @@ -547,19 +598,10 @@ def propose_draft_token_ids(sampled_token_ids): self.kv_connector_output = None hidden_states_cpu = hidden_states.detach().to("cpu").contiguous() - num_scheduled_tokens_np = getattr(self, "_omni_num_scheduled_tokens_np", None) - if num_scheduled_tokens_np is None: - req_ids = self.input_batch.req_ids - num_scheduled_tokens_np = np.array( - [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids], - dtype=np.int32, - ) - - self._process_additional_information_updates( - hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output - ) pooler_output: list[dict[str, object]] = [] + buffer_payload_keys = tuple(getattr(self.model, "pooler_output_buffer_keys", ()) or ()) + pooler_buffer_source = overlay_intermediate_buffer or self.model_intermediate_buffer for rid in req_ids_output_copy: idx = req_id_to_index_output_copy[rid] start = int(self.query_start_loc.cpu[idx]) @@ -589,6 +631,27 @@ def propose_draft_token_ids(sampled_token_ids): logger.error(f"Error in merge multimodal outputs: {e}") if mm_payload: payload.update(mm_payload) + if buffer_payload_keys: + req_buffer = pooler_buffer_source.get(rid, {}) + if isinstance(req_buffer, dict): + for key in buffer_payload_keys: + if key not in req_buffer: + continue + value = req_buffer[key] + if isinstance(value, torch.Tensor): + payload[key] = 
value.detach().to("cpu").contiguous() + elif isinstance(value, list): + payload[key] = [ + item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item + for item in value + ] + elif isinstance(value, tuple): + payload[key] = tuple( + item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item + for item in value + ) + else: + payload[key] = value pooler_output.append(payload) with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): if self.model_config.enable_return_routed_experts: diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index 05785d7aa6c..7c9de942fd1 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -265,9 +265,9 @@ def execute_model( # Run the model. # Use persistent buffers for CUDA graphs. - # When spec decode is enabled, delay clearing connector metadata - # until after draft model runs in sample_tokens. - clear_kv_metadata = self.speculative_config is None + # When spec decode is enabled, delay connector finalization until + # after draft model runs in sample_tokens. + defer_finalize = self.speculative_config is not None with ( set_forward_context( attn_metadata, @@ -280,9 +280,7 @@ def execute_model( slot_mapping=slot_mappings, # OMNI: required for KV cache operations ), record_function_or_nullcontext("Forward"), - self.maybe_get_kv_connector_output( - scheduler_output, clear_metadata=clear_kv_metadata - ) as kv_connector_output, + self.maybe_get_kv_connector_output(scheduler_output, defer_finalize=defer_finalize) as kv_connector_output, ): outputs = self._run_generation_model( input_ids=input_ids, @@ -351,9 +349,9 @@ def sample_tokens( ) = self.execute_model_state self.execute_model_state = None - # Clear KV connector metadata after draft model runs (if spec decode). + # Finalize KV connector after draft model runs (if spec decode). 
if self.speculative_config is not None: - self.clear_kv_connector_metadata() + self.finalize_kv_connector() pooler_output: list[object] = [] if isinstance(multimodal_outputs, torch.Tensor): diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index e17d52bdd53..c66fa7f7ee3 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1006,6 +1006,26 @@ def _process_additional_information_updates( scheduler_output: "SchedulerOutput", ) -> None: """Process model-provided per-request updates and merge into model_intermediate_buffer.""" + updates = self._collect_additional_information_updates( + hidden_states, + multimodal_outputs, + num_scheduled_tokens_np, + scheduler_output, + ) + if not updates: + return + overlay_buffer = self._build_overlay_intermediate_buffer(list(updates), updates) + self._commit_intermediate_buffer_overlay(overlay_buffer, req_ids=list(updates)) + + def _collect_additional_information_updates( + self, + hidden_states: torch.Tensor, + multimodal_outputs: object, + num_scheduled_tokens_np: np.ndarray, + scheduler_output: "SchedulerOutput", + ) -> dict[str, dict[str, Any]]: + """Collect model-provided per-request updates without mutating runtime state.""" + updates: dict[str, dict[str, Any]] = {} try: # execute the custom postprocess function # TODO(Peiqi): do we have a more elegant way to do this? @@ -1018,7 +1038,9 @@ def _process_additional_information_updates( # only consider to store data into update dict. 
hidden_states_slice = hidden_states[s:e] update_dict = self.model.postprocess(hidden_states_slice, **req_infos) - self._update_intermediate_buffer(req_id, update_dict) + normalized = self._normalize_intermediate_update(update_dict) + if normalized: + updates[req_id] = normalized except Exception as e: logger.error( f"Error merging for requests:{self.input_batch.req_ids} " @@ -1028,6 +1050,7 @@ def _process_additional_information_updates( import traceback traceback.print_exc() + return updates def _collect_additional_information_for_prefill( self, @@ -1158,9 +1181,15 @@ def _preprocess( model_kwargs = self._init_model_kwargs() input_ids = self.input_ids.gpu[:num_input_tokens] elif getattr(self.model, "has_preprocess", False): - # Use pre-allocated buffer for CUDA graph compatibility. + # Raw-token preprocess stages should see input_ids first, then + # materialize embeddings back into the pre-allocated buffer before + # the final forward. input_ids = self.input_ids.gpu[:num_input_tokens] - inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] + inputs_embeds = ( + None + if getattr(self.model, "requires_raw_input_tokens", False) + else self.inputs_embeds.gpu[:num_input_tokens] + ) model_kwargs = self._init_model_kwargs() else: # For text-only models, we use token ids as input. 
@@ -1220,6 +1249,7 @@ def _preprocess( # Overlay custom prompt_embeds per request for the prompt portion; # collect additional_information (tensor/list) for prefill portion only decode_req_ids = [] + preprocess_results: list[dict[str, Any]] = [] for req_index, req_id in enumerate(self.input_batch.req_ids): req_infos = self.model_intermediate_buffer.get(req_id, {}) @@ -1237,18 +1267,41 @@ def _preprocess( req_input_ids, req_embeds, update_dict = self.model.preprocess( input_ids=input_ids[s:e], input_embeds=embed_slice, **req_infos ) - if inputs_embeds is None: - inputs_embeds = torch.empty( - (input_ids.shape[0], req_embeds.shape[-1]), - device=req_embeds.device, - dtype=req_embeds.dtype, - ) + preprocess_results.append( + { + "req_id": req_id, + "start": s, + "end": e, + "span_len": span_len, + "req_input_ids": req_input_ids, + "req_embeds": req_embeds, + "update_dict": update_dict, + } + ) + + inputs_embeds, resolved_req_embeds = self._resolve_preprocess_batch_inputs_embeds( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + preprocess_results=preprocess_results, + ) + + for result, resolved_req_embeds_item in zip(preprocess_results, resolved_req_embeds, strict=False): + req_id = result["req_id"] + s = result["start"] + span_len = result["span_len"] + req_input_ids = result["req_input_ids"] + update_dict = result["update_dict"] if hasattr(self.model, "talker_mtp") and span_len == 1: + if resolved_req_embeds_item is None: + raise RuntimeError( + "talker_mtp requires preprocess embeddings for decode steps, " + f"but model.preprocess returned req_embeds=None for request {req_id}." 
+ ) last_talker_hidden, text_step = update_dict.pop("mtp_inputs") decode_slice = slice(len(decode_req_ids), len(decode_req_ids) + 1) self.talker_mtp_input_ids.gpu[decode_slice].copy_(req_input_ids) - self.talker_mtp_inputs_embeds.gpu[decode_slice].copy_(req_embeds) + self.talker_mtp_inputs_embeds.gpu[decode_slice].copy_(resolved_req_embeds_item) self.last_talker_hidden.gpu[decode_slice].copy_(last_talker_hidden) self.text_step.gpu[decode_slice].copy_(text_step) decode_req_ids.append(req_id) @@ -1256,11 +1309,12 @@ def _preprocess( # TODO(Peiqi): the merge stage could move out from the critical path self._merge_additional_information_update(req_id, update_dict) - # update the inputs_embeds and input_ids - seg_len = min(span_len, req_embeds.shape[0]) - inputs_embeds[s : s + seg_len] = req_embeds[:seg_len] - if isinstance(req_input_ids, torch.Tensor) and req_input_ids.numel() == seg_len: - input_ids[s : s + seg_len] = req_input_ids + if inputs_embeds is not None: + assert resolved_req_embeds_item is not None + seg_len = min(span_len, resolved_req_embeds_item.shape[0]) + inputs_embeds[s : s + seg_len] = resolved_req_embeds_item[:seg_len] + if isinstance(req_input_ids, torch.Tensor) and req_input_ids.numel() == span_len: + input_ids[s : s + span_len] = req_input_ids # run talker mtp decode if hasattr(self.model, "talker_mtp"): @@ -1334,31 +1388,131 @@ def _model_forward( return model_output def _update_intermediate_buffer(self, req_id: str, upd: dict) -> None: - if not isinstance(upd, dict) or not upd: + normalized = self._normalize_intermediate_update(upd) + if not normalized: return req_state = self.requests.get(req_id) if req_state is None: return - # Check if the model declares keys that should stay on GPU + # Preserve upstream GPU-resident buffer behavior for models that + # explicitly opt in, while keeping normalized CPU values elsewhere. 
gpu_keys: set[str] = set() if hasattr(self, "model") and hasattr(self.model, "gpu_resident_buffer_keys"): gpu_keys = self.model.gpu_resident_buffer_keys existing = self.model_intermediate_buffer.setdefault(req_id, {}) - for k, v in upd.items(): - if isinstance(v, torch.Tensor): - if k in gpu_keys: - existing[k] = v.detach().clone() - else: - existing[k] = v.detach().to("cpu").contiguous() - elif isinstance(v, list): - existing[k] = [ - (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in v - ] + for key, value in upd.items(): + if key in gpu_keys and isinstance(value, torch.Tensor): + existing[key] = value.detach().clone() else: - existing[k] = v + existing[key] = normalized[key] # Backward compatible: mirror to old setattr location setattr(req_state, "additional_information_cpu", existing) def _merge_additional_information_update(self, req_id, upd): logger.warning_once("_merge_additional_information_update is deprecated, use _update_intermediate_buffer") return self._update_intermediate_buffer(req_id, upd) + + @staticmethod + def _normalize_intermediate_update_value(value: Any) -> Any: + if isinstance(value, torch.Tensor): + return value.detach().to("cpu").contiguous() + if isinstance(value, list): + return [OmniGPUModelRunner._normalize_intermediate_update_value(item) for item in value] + if isinstance(value, tuple): + return tuple(OmniGPUModelRunner._normalize_intermediate_update_value(item) for item in value) + return value + + def _normalize_intermediate_update(self, upd: dict[str, Any] | None) -> dict[str, Any]: + if not isinstance(upd, dict) or not upd: + return {} + return {k: self._normalize_intermediate_update_value(v) for k, v in upd.items()} + + def _embed_input_ids_for_preprocess( + self, + input_ids: torch.Tensor, + *, + like: torch.Tensor | None = None, + ) -> torch.Tensor: + flat_input_ids = input_ids.reshape(-1) + + embed_input_ids = getattr(self.model, "embed_input_ids", None) + if 
callable(embed_input_ids): + req_embeds = embed_input_ids(input_ids=flat_input_ids) + else: + get_language_model = getattr(self.model, "get_language_model", None) + language_model = get_language_model() if callable(get_language_model) else None + lm_embed_input_ids = getattr(language_model, "embed_input_ids", None) + if callable(lm_embed_input_ids): + req_embeds = lm_embed_input_ids(flat_input_ids) + else: + get_input_embeddings = getattr(self.model, "get_input_embeddings", None) + if not callable(get_input_embeddings): + raise RuntimeError( + "Model preprocess returned req_embeds=None, but the runner " + "could not resolve a token embedding function." + ) + req_embeds = get_input_embeddings()(flat_input_ids) + + if req_embeds.ndim == 1: + req_embeds = req_embeds.unsqueeze(0) + elif req_embeds.ndim > 2: + req_embeds = req_embeds.reshape(-1, req_embeds.shape[-1]) + + if like is not None: + req_embeds = req_embeds.to(device=like.device, dtype=like.dtype) + return req_embeds + + def _resolve_preprocess_batch_inputs_embeds( + self, + *, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor | None, + preprocess_results: list[dict[str, Any]], + ) -> tuple[torch.Tensor | None, list[torch.Tensor | None]]: + if not any(result["req_embeds"] is not None for result in preprocess_results): + return None, [None] * len(preprocess_results) + + batch_inputs_embeds = inputs_embeds + if batch_inputs_embeds is None: + batch_inputs_embeds = self.inputs_embeds.gpu[: input_ids.shape[0]] + + resolved_req_embeds: list[torch.Tensor | None] = [] + for result in preprocess_results: + req_embeds = result["req_embeds"] + if req_embeds is None: + req_embeds = self._embed_input_ids_for_preprocess( + result["req_input_ids"], + like=batch_inputs_embeds[result["start"] : result["end"]], + ) + resolved_req_embeds.append(req_embeds) + + return batch_inputs_embeds, resolved_req_embeds + + def _build_overlay_intermediate_buffer( + self, + req_ids: list[str], + pending_updates: dict[str, dict[str, 
Any]] | None = None, + ) -> dict[str, dict[str, Any]]: + overlay_buffer: dict[str, dict[str, Any]] = {} + for req_id in req_ids: + merged = dict(self.model_intermediate_buffer.get(req_id, {})) + if pending_updates and req_id in pending_updates: + merged.update(pending_updates[req_id]) + overlay_buffer[req_id] = merged + return overlay_buffer + + def _commit_intermediate_buffer_overlay( + self, + overlay_buffer: dict[str, dict[str, Any]], + *, + invalid_req_indices: list[int] | None = None, + req_ids: list[str] | None = None, + ) -> None: + if not overlay_buffer: + return + req_ids = req_ids or list(self.input_batch.req_ids) + invalid_index_set = set(invalid_req_indices or []) + for req_index, req_id in enumerate(req_ids): + if req_index in invalid_index_set: + continue + self._update_intermediate_buffer(req_id, overlay_buffer.get(req_id, {})) From d51fb28ed5d195458fe2171b82e6590a660ebb89 Mon Sep 17 00:00:00 2001 From: "ramos.ma" Date: Thu, 12 Mar 2026 09:35:04 +0800 Subject: [PATCH 02/14] Fix FunAudioChat remote validation regressions Signed-off-by: ramos.ma --- .../models/test_funaudiochat_native.py | 62 +-------------- tests/worker/test_omni_gpu_model_runner.py | 76 +++++++++++++++++++ .../models/funaudiochat/funaudiochat.py | 21 ++--- .../stage_configs/funaudiochat_s2s.yaml | 6 +- vllm_omni/worker/gpu_model_runner.py | 16 +++- 5 files changed, 103 insertions(+), 78 deletions(-) diff --git a/tests/model_executor/models/test_funaudiochat_native.py b/tests/model_executor/models/test_funaudiochat_native.py index 09786a4079c..de270e9d018 100644 --- a/tests/model_executor/models/test_funaudiochat_native.py +++ b/tests/model_executor/models/test_funaudiochat_native.py @@ -224,7 +224,6 @@ def test_postprocess_sampled_tokens_updates_buffer_from_final_sampled_token(): fac_mod._GENERATE_SPEECH_KEY: False, fac_mod._FORCE_AUDIO_BOS_KEY: False, fac_mod._FINISH_SPEECH_KEY: False, - fac_mod._RAW_TEXT_TOKEN_ID_KEY: 42, } } @@ -249,7 +248,6 @@ def 
test_postprocess_sampled_tokens_force_text_abos_overrides_sampled_token(): fac_mod._GENERATE_SPEECH_KEY: False, fac_mod._FORCE_AUDIO_BOS_KEY: True, fac_mod._FINISH_SPEECH_KEY: False, - fac_mod._RAW_TEXT_TOKEN_ID_KEY: 7, } } @@ -265,55 +263,6 @@ def test_postprocess_sampled_tokens_force_text_abos_overrides_sampled_token(): assert model_intermediate_buffer["req0"][fac_mod._FORCE_AUDIO_BOS_KEY] is False -def test_postprocess_sampled_tokens_uses_raw_argmax_when_text_greedy_is_enabled(): - model = _make_model_stub() - sampled_token_ids = torch.tensor([7], dtype=torch.long) - model_intermediate_buffer = { - "req0": { - fac_mod._GENERATE_SPEECH_KEY: False, - fac_mod._FORCE_AUDIO_BOS_KEY: False, - fac_mod._FINISH_SPEECH_KEY: False, - fac_mod._RAW_TEXT_TOKEN_ID_KEY: 42, - } - } - - updated = model.postprocess_sampled_tokens( - sampled_token_ids=sampled_token_ids, - req_ids=["req0"], - req_id_to_index={"req0": 0}, - model_intermediate_buffer=model_intermediate_buffer, - ) - - assert updated.tolist() == [42] - assert model_intermediate_buffer["req0"][fac_mod._GENERATE_SPEECH_KEY] is True - assert model_intermediate_buffer["req0"][fac_mod._FORCE_AUDIO_BOS_KEY] is False - - -def test_postprocess_sampled_tokens_respects_sampled_token_when_text_greedy_disabled(): - model = _make_model_stub() - model.sp_gen_kwargs["text_greedy"] = False - sampled_token_ids = torch.tensor([7], dtype=torch.long) - model_intermediate_buffer = { - "req0": { - fac_mod._GENERATE_SPEECH_KEY: False, - fac_mod._FORCE_AUDIO_BOS_KEY: False, - fac_mod._FINISH_SPEECH_KEY: False, - fac_mod._RAW_TEXT_TOKEN_ID_KEY: 42, - } - } - - updated = model.postprocess_sampled_tokens( - sampled_token_ids=sampled_token_ids, - req_ids=["req0"], - req_id_to_index={"req0": 0}, - model_intermediate_buffer=model_intermediate_buffer, - ) - - assert updated.tolist() == [7] - assert model_intermediate_buffer["req0"][fac_mod._GENERATE_SPEECH_KEY] is False - assert model_intermediate_buffer["req0"][fac_mod._FORCE_AUDIO_BOS_KEY] is 
False - - def test_postprocess_sampled_tokens_overwrites_emitted_token_to_audio_eos_on_finish(): model = _make_model_stub() sampled_token_ids = torch.tensor([7], dtype=torch.long) @@ -322,7 +271,6 @@ def test_postprocess_sampled_tokens_overwrites_emitted_token_to_audio_eos_on_fin fac_mod._GENERATE_SPEECH_KEY: True, fac_mod._FORCE_AUDIO_BOS_KEY: False, fac_mod._FINISH_SPEECH_KEY: True, - fac_mod._RAW_TEXT_TOKEN_ID_KEY: 7, } } @@ -347,7 +295,6 @@ def test_postprocess_sampled_tokens_noops_for_spec_decode_shapes(): fac_mod._GENERATE_SPEECH_KEY: False, fac_mod._FORCE_AUDIO_BOS_KEY: True, fac_mod._FINISH_SPEECH_KEY: False, - fac_mod._RAW_TEXT_TOKEN_ID_KEY: 42, } } @@ -383,7 +330,7 @@ def test_chunked_prefill_preprocess_keeps_speech_inactive(): assert torch.equal(second_update["audio_token_ids"], torch.full((1, 5), -1, dtype=torch.long)) -def test_preprocess_single_token_text_decode_keeps_input_id_path(): +def test_preprocess_single_token_text_decode_returns_text_embeddings(): model = _make_model_stub() _, req_embeds, _ = model.preprocess( @@ -391,10 +338,10 @@ def test_preprocess_single_token_text_decode_keeps_input_id_path(): input_embeds=None, ) - assert req_embeds is None + assert torch.equal(req_embeds, torch.zeros((1, 4), dtype=torch.float32)) -def test_preprocess_first_speech_step_without_codec_history_keeps_input_id_path(): +def test_preprocess_first_speech_step_without_codec_history_returns_text_embeddings(): model = _make_model_stub() _, req_embeds, _ = model.preprocess( @@ -406,7 +353,7 @@ def test_preprocess_first_speech_step_without_codec_history_keeps_input_id_path( }, ) - assert req_embeds is None + assert torch.equal(req_embeds, torch.zeros((1, 4), dtype=torch.float32)) def test_preprocess_active_speech_with_codec_history_blends_audio_features(): @@ -468,7 +415,6 @@ def test_postprocess_prefill_warmup_updates_cache_without_emitting_audio(): fac_mod._FORCE_AUDIO_BOS_KEY: True, fac_mod._FINISH_SPEECH_KEY: False, fac_mod._GENERATE_SPEECH_KEY: False, - 
fac_mod._RAW_TEXT_TOKEN_ID_KEY: 1, fac_mod._SPEECH_IDS_KEY: torch.empty((1, 0), dtype=torch.long), "_run_prefill_crq_warmup": True, "_prefill_input_ids": torch.tensor([1, 2, 3], dtype=torch.long), diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py index 1379e1d1e5b..213ab7cf836 100644 --- a/tests/worker/test_omni_gpu_model_runner.py +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -87,6 +87,29 @@ def preprocess(self, input_ids, input_embeds, **info_dict): return input_ids + 100, req_embeds, {"marker_seen": info_dict.get("marker")} +class MultimodalPreprocessModel(torch.nn.Module): + """Tracks fallback raw-token slices when multimodal preprocess runs from embeds.""" + + has_preprocess = True + requires_raw_input_tokens = False + + def __init__(self, hidden_size: int = 4): + super().__init__() + self.hidden_size = hidden_size + self.observed_input_ids = [] + self.observed_input_embeds = [] + + def embed_input_ids(self, input_ids, multimodal_embeddings=None, is_multimodal=None): + del multimodal_embeddings, is_multimodal + return input_ids.to(dtype=torch.float32).unsqueeze(-1).repeat(1, self.hidden_size) + + def preprocess(self, input_ids, input_embeds, **info_dict): + self.observed_input_ids.append(input_ids.clone()) + self.observed_input_embeds.append(input_embeds.clone() if isinstance(input_embeds, torch.Tensor) else None) + req_embeds = input_ids.to(dtype=torch.float32).unsqueeze(-1).repeat(1, self.hidden_size) + return input_ids, req_embeds, {"marker_seen": info_dict.get("marker")} + + class DummyTalkerMTP(torch.nn.Module): """A fake talker_mtp module for deterministic CPU testing.""" @@ -190,6 +213,21 @@ def _make_preprocess_runner(model, hidden_size=4): return runner +def _make_mm_preprocess_runner(model, hidden_size=4): + runner = _make_preprocess_runner(model, hidden_size=hidden_size) + runner.supports_mm_inputs = True + runner.encoder_cache = None + runner._execute_mm_encoder = lambda scheduler_output: 
None + runner._gather_mm_embeddings = lambda scheduler_output: (None, None) + runner._prepare_mm_inputs = lambda num_input_tokens: ( + None, + runner.inputs_embeds.gpu[:num_input_tokens], + ) + runner._extract_mm_kwargs = lambda scheduler_output: {} + runner.maybe_get_ec_connector_output = _noop_forward_context + return runner + + class StopAfterBookkeepingError(Exception): pass @@ -472,3 +510,41 @@ def test_preprocess_passes_none_input_embeds_for_raw_token_models(monkeypatch): ), ) assert runner.model_intermediate_buffer["r1"]["marker_seen"] == "r1" + + +def test_preprocess_uses_buffered_input_ids_when_multimodal_path_returns_none(monkeypatch): + import vllm_omni.worker.gpu_model_runner as mod + + monkeypatch.setattr(mod, "get_pp_group", lambda: SimpleNamespace(is_first_rank=True)) + + runner = _make_mm_preprocess_runner(MultimodalPreprocessModel(hidden_size=4), hidden_size=4) + scheduler_output = SimpleNamespace( + total_num_scheduled_tokens=2, + num_scheduled_tokens={"r1": 2}, + scheduled_encoder_inputs=None, + ) + + input_ids, inputs_embeds, *_ = OmniGPUModelRunner._preprocess( + runner, + scheduler_output, + num_input_tokens=2, + ) + + assert input_ids is None + assert len(runner.model.observed_input_ids) == 1 + assert torch.equal(runner.model.observed_input_ids[0], torch.tensor([1, 2], dtype=torch.int32)) + assert torch.equal( + runner.model.observed_input_embeds[0], + runner.inputs_embeds.gpu[:2], + ) + assert torch.equal( + inputs_embeds, + torch.tensor( + [ + [1.0, 1.0, 1.0, 1.0], + [2.0, 2.0, 2.0, 2.0], + ], + dtype=torch.float32, + ), + ) + assert runner.model_intermediate_buffer["r1"]["marker_seen"] == "r1" diff --git a/vllm_omni/model_executor/models/funaudiochat/funaudiochat.py b/vllm_omni/model_executor/models/funaudiochat/funaudiochat.py index 83883d0a1e7..86c69bb144b 100644 --- a/vllm_omni/model_executor/models/funaudiochat/funaudiochat.py +++ b/vllm_omni/model_executor/models/funaudiochat/funaudiochat.py @@ -63,7 +63,6 @@ _FORCE_AUDIO_BOS_KEY = 
"funaudiochat_force_audio_bos_pending" _FINISH_SPEECH_KEY = "funaudiochat_finish_speech" _GENERATE_SPEECH_KEY = "funaudiochat_generate_speech" -_RAW_TEXT_TOKEN_ID_KEY = "funaudiochat_raw_text_token_id" _SPEECH_IDS_KEY = "funaudiochat_speech_ids" _TEXT_INPUT_IDS_KEY = "funaudiochat_text_input_ids" _TEXT_SEQ_LEN_KEY = "funaudiochat_text_seq_len" @@ -73,7 +72,7 @@ class FunAudioChatForConditionalGeneration(_NativeFunAudioChatBase, SupportsMultiModal): supports_multimodal_raw_input_only = True supports_multimodal = True - requires_raw_input_tokens = True + requires_raw_input_tokens = False input_modalities = "audio" pooler_output_buffer_keys = ("audio_token_ids",) @@ -408,6 +407,7 @@ def preprocess( input_embeds: torch.Tensor | None, **info_dict: Any, ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: + del input_embeds if not self._batch_preprocess_in_progress: self._batch_req_infos = [] self._batch_sidecar_results = [] @@ -416,9 +416,8 @@ def preprocess( span_len = int(input_ids.shape[0]) device = input_ids.device - req_embeds = input_embeds - if req_embeds is None and span_len > 1: - req_embeds = self.get_language_model().embed_input_ids(input_ids.reshape(-1)) + text_embeds = self.get_language_model().embed_input_ids(input_ids.reshape(-1)) + req_embeds = text_embeds generate_speech = bool(info_dict.get(_GENERATE_SPEECH_KEY, False)) force_audio_bos_pending = bool(info_dict.get(_FORCE_AUDIO_BOS_KEY, self.sp_gen_kwargs["force_text_abos"])) @@ -429,7 +428,7 @@ def preprocess( ) if span_len == 1: - current_text_embed = self.get_language_model().embed_input_ids(input_ids.reshape(-1)).reshape(1, -1) + current_text_embed = text_embeds.reshape(1, -1) if generate_speech and speech_ids.shape[-1] >= int(self.config.audio_config.group_size): last_group = speech_ids[:, -int(self.config.audio_config.group_size) :] audio_features = self.audio_tower(last_group.to(device=device, dtype=torch.long)) @@ -474,12 +473,10 @@ def compute_logits( self._batch_preprocess_in_progress = 
False return None - raw_argmax_token_ids = torch.argmax(logits, dim=-1).detach().to("cpu") self._batch_sidecar_results = [] for idx, req_info in enumerate(self._batch_req_infos): force_audio_bos_pending = bool(req_info.get(_FORCE_AUDIO_BOS_KEY, False)) speech_active = bool(req_info.get(_GENERATE_SPEECH_KEY, False)) - raw_text_token_id = int(raw_argmax_token_ids[idx].item()) sidecar_result = { _AUDIO_TOKEN_IDS_KEY: self._empty_audio_token_ids(hidden_states.device).to("cpu"), @@ -488,7 +485,6 @@ def compute_logits( _FORCE_AUDIO_BOS_KEY: force_audio_bos_pending, _FINISH_SPEECH_KEY: False, _GENERATE_SPEECH_KEY: speech_active, - _RAW_TEXT_TOKEN_ID_KEY: raw_text_token_id, _SPEECH_IDS_KEY: req_info.get(_SPEECH_IDS_KEY), "audio_token_ids": self._empty_audio_token_ids(hidden_states.device).to("cpu"), } @@ -559,7 +555,6 @@ def postprocess(self, hidden_states: torch.Tensor, **_: Any) -> dict[str, Any]: _FORCE_AUDIO_BOS_KEY: sidecar_result[_FORCE_AUDIO_BOS_KEY], _FINISH_SPEECH_KEY: sidecar_result[_FINISH_SPEECH_KEY], _GENERATE_SPEECH_KEY: sidecar_result[_GENERATE_SPEECH_KEY], - _RAW_TEXT_TOKEN_ID_KEY: sidecar_result[_RAW_TEXT_TOKEN_ID_KEY], _SPEECH_IDS_KEY: sidecar_result[_SPEECH_IDS_KEY], "audio_token_ids": sidecar_result["audio_token_ids"], } @@ -592,16 +587,12 @@ def postprocess_sampled_tokens( token_slot = updated_token_ids[idx] if updated_token_ids.ndim == 1 else updated_token_ids[idx, 0] original_token_id = int(token_slot.item()) - sampled_token_id = original_token_id - raw_text_token_id = req_buffer.get(_RAW_TEXT_TOKEN_ID_KEY) - if self.sp_gen_kwargs["text_greedy"] and raw_text_token_id is not None: - sampled_token_id = int(raw_text_token_id) speech_active = bool(req_buffer.get(_GENERATE_SPEECH_KEY, False)) force_audio_bos_pending = bool(req_buffer.get(_FORCE_AUDIO_BOS_KEY, False)) finish_speech = bool(req_buffer.pop(_FINISH_SPEECH_KEY, False)) final_token_id, next_speech_active, next_force_audio_bos_pending = self._resolve_next_speech_state( - 
sampled_token_id=sampled_token_id, + sampled_token_id=original_token_id, generate_speech=speech_active, finish_speech=finish_speech, force_audio_bos_pending=force_audio_bos_pending, diff --git a/vllm_omni/model_executor/stage_configs/funaudiochat_s2s.yaml b/vllm_omni/model_executor/stage_configs/funaudiochat_s2s.yaml index aca184567f0..07a8e84eea6 100644 --- a/vllm_omni/model_executor/stage_configs/funaudiochat_s2s.yaml +++ b/vllm_omni/model_executor/stage_configs/funaudiochat_s2s.yaml @@ -23,9 +23,9 @@ stage_args: final_output_type: text is_comprehension: true default_sampling_params: - temperature: 0.8 - top_p: 0.9 - top_k: 0 + temperature: 0.0 + top_p: 1.0 + top_k: -1 repetition_penalty: 1.2 max_tokens: 2048 stop_token_ids: [151645] diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index c66fa7f7ee3..00ac48858e6 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1250,6 +1250,12 @@ def _preprocess( # collect additional_information (tensor/list) for prefill portion only decode_req_ids = [] preprocess_results: list[dict[str, Any]] = [] + preprocess_input_ids = input_ids + if preprocess_input_ids is None: + # Multimodal stages can enter preprocess with embed-only inputs, + # but model-local preprocess may still require the raw token ids + # for the scheduled span. 
+ preprocess_input_ids = self.input_ids.gpu[:num_input_tokens] for req_index, req_id in enumerate(self.input_batch.req_ids): req_infos = self.model_intermediate_buffer.get(req_id, {}) @@ -1265,7 +1271,9 @@ def _preprocess( # call the custom process function embed_slice = inputs_embeds[s:e] if inputs_embeds is not None else None req_input_ids, req_embeds, update_dict = self.model.preprocess( - input_ids=input_ids[s:e], input_embeds=embed_slice, **req_infos + input_ids=preprocess_input_ids[s:e], + input_embeds=embed_slice, + **req_infos, ) preprocess_results.append( { @@ -1313,7 +1321,11 @@ def _preprocess( assert resolved_req_embeds_item is not None seg_len = min(span_len, resolved_req_embeds_item.shape[0]) inputs_embeds[s : s + seg_len] = resolved_req_embeds_item[:seg_len] - if isinstance(req_input_ids, torch.Tensor) and req_input_ids.numel() == span_len: + if ( + input_ids is not None + and isinstance(req_input_ids, torch.Tensor) + and req_input_ids.numel() == span_len + ): input_ids[s : s + span_len] = req_input_ids # run talker mtp decode From 573e4cecde4dae5c08c0d8c437ec0a5d1ac251e0 Mon Sep 17 00:00:00 2001 From: "ramos.ma" Date: Thu, 12 Mar 2026 12:27:03 +0800 Subject: [PATCH 03/14] Add FunAudioChat sampler regression test Signed-off-by: ramos.ma --- .../models/test_funaudiochat_native.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/model_executor/models/test_funaudiochat_native.py b/tests/model_executor/models/test_funaudiochat_native.py index de270e9d018..b2f8b8ba421 100644 --- a/tests/model_executor/models/test_funaudiochat_native.py +++ b/tests/model_executor/models/test_funaudiochat_native.py @@ -240,6 +240,30 @@ def test_postprocess_sampled_tokens_updates_buffer_from_final_sampled_token(): assert fac_mod._FINISH_SPEECH_KEY not in model_intermediate_buffer["req0"] +def test_postprocess_sampled_tokens_preserves_regular_sampler_token(): + model = _make_model_stub() + sampled_token_ids = torch.tensor([7], dtype=torch.long) + 
model_intermediate_buffer = { + "req0": { + fac_mod._GENERATE_SPEECH_KEY: False, + fac_mod._FORCE_AUDIO_BOS_KEY: False, + fac_mod._FINISH_SPEECH_KEY: False, + } + } + + updated = model.postprocess_sampled_tokens( + sampled_token_ids=sampled_token_ids, + req_ids=["req0"], + req_id_to_index={"req0": 0}, + model_intermediate_buffer=model_intermediate_buffer, + ) + + assert updated.tolist() == [7] + assert model_intermediate_buffer["req0"][fac_mod._GENERATE_SPEECH_KEY] is False + assert model_intermediate_buffer["req0"][fac_mod._FORCE_AUDIO_BOS_KEY] is False + assert fac_mod._FINISH_SPEECH_KEY not in model_intermediate_buffer["req0"] + + def test_postprocess_sampled_tokens_force_text_abos_overrides_sampled_token(): model = _make_model_stub() sampled_token_ids = torch.tensor([7], dtype=torch.long) From 925197eab330bde2b27a4a1662087a329f1fe99a Mon Sep 17 00:00:00 2001 From: "ramos.ma" Date: Thu, 12 Mar 2026 14:28:10 +0800 Subject: [PATCH 04/14] Reduce FunAudioChat runner scope Signed-off-by: mayufeng --- vllm_omni/worker/gpu_ar_model_runner.py | 78 +++++++++++- vllm_omni/worker/gpu_model_runner.py | 159 ++---------------------- 2 files changed, 83 insertions(+), 154 deletions(-) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index ad9a583edd9..484ca7f90b9 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -417,6 +417,76 @@ def _postprocess_sampled_token_ids( ) return sampled_token_ids if corrected_sampled_token_ids is None else corrected_sampled_token_ids + def _collect_pending_postprocess_updates( + self, + hidden_states: torch.Tensor, + multimodal_outputs: object, + num_scheduled_tokens_np: np.ndarray, + ) -> dict[str, dict[str, Any]]: + updates: dict[str, dict[str, Any]] = {} + try: + if hasattr(self.model, "has_postprocess") and self.model.has_postprocess: + for req_index, req_id in enumerate(self.input_batch.req_ids): + req_infos = 
self.model_intermediate_buffer.get(req_id, {}) + start_offset = int(self.query_start_loc.cpu[req_index]) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + s, e = start_offset, start_offset + sched_tokens + hidden_states_slice = hidden_states[s:e] + update_dict = self.model.postprocess(hidden_states_slice, **req_infos) + normalized = self._normalize_intermediate_update(update_dict) + if normalized: + updates[req_id] = normalized + except Exception as e: + logger.error( + f"Error merging for requests:{self.input_batch.req_ids} " + f"additional information update: {e}, with the multimodal_outputs " + f"as {multimodal_outputs}" + ) + import traceback + + traceback.print_exc() + return updates + + def _collect_additional_information_updates( + self, + hidden_states: torch.Tensor, + multimodal_outputs: object, + num_scheduled_tokens_np: np.ndarray, + scheduler_output: SchedulerOutput | None = None, + ) -> dict[str, dict[str, Any]]: + del scheduler_output + return self._collect_pending_postprocess_updates( + hidden_states, + multimodal_outputs, + num_scheduled_tokens_np, + ) + + def _build_overlay_intermediate_buffer_local( + self, + pending_updates: dict[str, dict[str, Any]], + ) -> dict[str, dict[str, Any]]: + overlay_buffer: dict[str, dict[str, Any]] = {} + for req_id in self.input_batch.req_ids: + merged = dict(self.model_intermediate_buffer.get(req_id, {})) + if req_id in pending_updates: + merged.update(pending_updates[req_id]) + overlay_buffer[req_id] = merged + return overlay_buffer + + def _commit_overlay_intermediate_buffer_local( + self, + overlay_buffer: dict[str, dict[str, Any]], + *, + invalid_req_indices: list[int] | None = None, + ) -> None: + invalid_index_set = set(invalid_req_indices or []) + for req_index, req_id in enumerate(self.input_batch.req_ids): + if req_index in invalid_index_set: + continue + update_dict = overlay_buffer.get(req_id, {}) + if update_dict: + self._update_intermediate_buffer(req_id, update_dict) + @torch.inference_mode() 
def sample_tokens( self, @@ -486,10 +556,7 @@ def sample_tokens( overlay_intermediate_buffer = None postprocess_hook = getattr(self.model, "postprocess_sampled_tokens", None) if pending_intermediate_updates or postprocess_hook is not None: - overlay_intermediate_buffer = self._build_overlay_intermediate_buffer( - self.input_batch.req_ids.copy(), - pending_intermediate_updates, - ) + overlay_intermediate_buffer = self._build_overlay_intermediate_buffer_local(pending_intermediate_updates) with record_function_or_nullcontext("gpu_model_runner: postprocess_sampled_tokens"): sampler_output.sampled_token_ids = self._postprocess_sampled_token_ids( @@ -573,10 +640,9 @@ def propose_draft_token_ids(sampled_token_ids): ) if overlay_intermediate_buffer is not None: - self._commit_intermediate_buffer_overlay( + self._commit_overlay_intermediate_buffer_local( overlay_intermediate_buffer, invalid_req_indices=invalid_req_indices, - req_ids=self.input_batch.req_ids.copy(), ) if propose_drafts_after_bookkeeping: diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 00ac48858e6..c0e4d82862f 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1006,26 +1006,6 @@ def _process_additional_information_updates( scheduler_output: "SchedulerOutput", ) -> None: """Process model-provided per-request updates and merge into model_intermediate_buffer.""" - updates = self._collect_additional_information_updates( - hidden_states, - multimodal_outputs, - num_scheduled_tokens_np, - scheduler_output, - ) - if not updates: - return - overlay_buffer = self._build_overlay_intermediate_buffer(list(updates), updates) - self._commit_intermediate_buffer_overlay(overlay_buffer, req_ids=list(updates)) - - def _collect_additional_information_updates( - self, - hidden_states: torch.Tensor, - multimodal_outputs: object, - num_scheduled_tokens_np: np.ndarray, - scheduler_output: "SchedulerOutput", - ) -> dict[str, dict[str, Any]]: 
- """Collect model-provided per-request updates without mutating runtime state.""" - updates: dict[str, dict[str, Any]] = {} try: # execute the custom postprocess function # TODO(Peiqi): do we have a more elegant way to do this? @@ -1038,9 +1018,7 @@ def _collect_additional_information_updates( # only consider to store data into update dict. hidden_states_slice = hidden_states[s:e] update_dict = self.model.postprocess(hidden_states_slice, **req_infos) - normalized = self._normalize_intermediate_update(update_dict) - if normalized: - updates[req_id] = normalized + self._update_intermediate_buffer(req_id, update_dict) except Exception as e: logger.error( f"Error merging for requests:{self.input_batch.req_ids} " @@ -1050,7 +1028,6 @@ def _collect_additional_information_updates( import traceback traceback.print_exc() - return updates def _collect_additional_information_for_prefill( self, @@ -1249,7 +1226,6 @@ def _preprocess( # Overlay custom prompt_embeds per request for the prompt portion; # collect additional_information (tensor/list) for prefill portion only decode_req_ids = [] - preprocess_results: list[dict[str, Any]] = [] preprocess_input_ids = input_ids if preprocess_input_ids is None: # Multimodal stages can enter preprocess with embed-only inputs, @@ -1275,41 +1251,17 @@ def _preprocess( input_embeds=embed_slice, **req_infos, ) - preprocess_results.append( - { - "req_id": req_id, - "start": s, - "end": e, - "span_len": span_len, - "req_input_ids": req_input_ids, - "req_embeds": req_embeds, - "update_dict": update_dict, - } - ) - - inputs_embeds, resolved_req_embeds = self._resolve_preprocess_batch_inputs_embeds( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - preprocess_results=preprocess_results, - ) - - for result, resolved_req_embeds_item in zip(preprocess_results, resolved_req_embeds, strict=False): - req_id = result["req_id"] - s = result["start"] - span_len = result["span_len"] - req_input_ids = result["req_input_ids"] - update_dict = 
result["update_dict"] + if req_embeds is None: + raise RuntimeError( + "Model preprocess must return req_embeds when has_preprocess=True, " + f"but got None for request {req_id}." + ) if hasattr(self.model, "talker_mtp") and span_len == 1: - if resolved_req_embeds_item is None: - raise RuntimeError( - "talker_mtp requires preprocess embeddings for decode steps, " - f"but model.preprocess returned req_embeds=None for request {req_id}." - ) last_talker_hidden, text_step = update_dict.pop("mtp_inputs") decode_slice = slice(len(decode_req_ids), len(decode_req_ids) + 1) self.talker_mtp_input_ids.gpu[decode_slice].copy_(req_input_ids) - self.talker_mtp_inputs_embeds.gpu[decode_slice].copy_(resolved_req_embeds_item) + self.talker_mtp_inputs_embeds.gpu[decode_slice].copy_(req_embeds) self.last_talker_hidden.gpu[decode_slice].copy_(last_talker_hidden) self.text_step.gpu[decode_slice].copy_(text_step) decode_req_ids.append(req_id) @@ -1317,10 +1269,11 @@ def _preprocess( # TODO(Peiqi): the merge stage could move out from the critical path self._merge_additional_information_update(req_id, update_dict) + if inputs_embeds is None: + inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] if inputs_embeds is not None: - assert resolved_req_embeds_item is not None - seg_len = min(span_len, resolved_req_embeds_item.shape[0]) - inputs_embeds[s : s + seg_len] = resolved_req_embeds_item[:seg_len] + seg_len = min(span_len, req_embeds.shape[0]) + inputs_embeds[s : s + seg_len] = req_embeds[:seg_len] if ( input_ids is not None and isinstance(req_input_ids, torch.Tensor) @@ -1438,93 +1391,3 @@ def _normalize_intermediate_update(self, upd: dict[str, Any] | None) -> dict[str if not isinstance(upd, dict) or not upd: return {} return {k: self._normalize_intermediate_update_value(v) for k, v in upd.items()} - - def _embed_input_ids_for_preprocess( - self, - input_ids: torch.Tensor, - *, - like: torch.Tensor | None = None, - ) -> torch.Tensor: - flat_input_ids = input_ids.reshape(-1) - - 
embed_input_ids = getattr(self.model, "embed_input_ids", None) - if callable(embed_input_ids): - req_embeds = embed_input_ids(input_ids=flat_input_ids) - else: - get_language_model = getattr(self.model, "get_language_model", None) - language_model = get_language_model() if callable(get_language_model) else None - lm_embed_input_ids = getattr(language_model, "embed_input_ids", None) - if callable(lm_embed_input_ids): - req_embeds = lm_embed_input_ids(flat_input_ids) - else: - get_input_embeddings = getattr(self.model, "get_input_embeddings", None) - if not callable(get_input_embeddings): - raise RuntimeError( - "Model preprocess returned req_embeds=None, but the runner " - "could not resolve a token embedding function." - ) - req_embeds = get_input_embeddings()(flat_input_ids) - - if req_embeds.ndim == 1: - req_embeds = req_embeds.unsqueeze(0) - elif req_embeds.ndim > 2: - req_embeds = req_embeds.reshape(-1, req_embeds.shape[-1]) - - if like is not None: - req_embeds = req_embeds.to(device=like.device, dtype=like.dtype) - return req_embeds - - def _resolve_preprocess_batch_inputs_embeds( - self, - *, - input_ids: torch.Tensor, - inputs_embeds: torch.Tensor | None, - preprocess_results: list[dict[str, Any]], - ) -> tuple[torch.Tensor | None, list[torch.Tensor | None]]: - if not any(result["req_embeds"] is not None for result in preprocess_results): - return None, [None] * len(preprocess_results) - - batch_inputs_embeds = inputs_embeds - if batch_inputs_embeds is None: - batch_inputs_embeds = self.inputs_embeds.gpu[: input_ids.shape[0]] - - resolved_req_embeds: list[torch.Tensor | None] = [] - for result in preprocess_results: - req_embeds = result["req_embeds"] - if req_embeds is None: - req_embeds = self._embed_input_ids_for_preprocess( - result["req_input_ids"], - like=batch_inputs_embeds[result["start"] : result["end"]], - ) - resolved_req_embeds.append(req_embeds) - - return batch_inputs_embeds, resolved_req_embeds - - def _build_overlay_intermediate_buffer( - 
self, - req_ids: list[str], - pending_updates: dict[str, dict[str, Any]] | None = None, - ) -> dict[str, dict[str, Any]]: - overlay_buffer: dict[str, dict[str, Any]] = {} - for req_id in req_ids: - merged = dict(self.model_intermediate_buffer.get(req_id, {})) - if pending_updates and req_id in pending_updates: - merged.update(pending_updates[req_id]) - overlay_buffer[req_id] = merged - return overlay_buffer - - def _commit_intermediate_buffer_overlay( - self, - overlay_buffer: dict[str, dict[str, Any]], - *, - invalid_req_indices: list[int] | None = None, - req_ids: list[str] | None = None, - ) -> None: - if not overlay_buffer: - return - req_ids = req_ids or list(self.input_batch.req_ids) - invalid_index_set = set(invalid_req_indices or []) - for req_index, req_id in enumerate(req_ids): - if req_index in invalid_index_set: - continue - self._update_intermediate_buffer(req_id, overlay_buffer.get(req_id, {})) From 2105386692f0e655ba7e3f4fb50370740c5f50a2 Mon Sep 17 00:00:00 2001 From: "ramos.ma" Date: Thu, 12 Mar 2026 16:44:50 +0800 Subject: [PATCH 05/14] Trim nonessential runner test scaffolding Signed-off-by: mayufeng --- tests/worker/test_omni_gpu_model_runner.py | 152 --------------------- 1 file changed, 152 deletions(-) diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py index 213ab7cf836..60d6610ca46 100644 --- a/tests/worker/test_omni_gpu_model_runner.py +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -70,46 +70,6 @@ def postprocess_sampled_tokens(self, sampled_token_ids, req_ids, req_id_to_index return None -class RawTokenPreprocessModel(torch.nn.Module): - """Tracks whether preprocess receives raw input ids without an embed slice.""" - - has_preprocess = True - requires_raw_input_tokens = True - - def __init__(self, hidden_size: int = 4): - super().__init__() - self.hidden_size = hidden_size - self.observed_input_embeds = [] - - def preprocess(self, input_ids, input_embeds, **info_dict): - 
self.observed_input_embeds.append(input_embeds) - req_embeds = input_ids.to(dtype=torch.float32).unsqueeze(-1).repeat(1, self.hidden_size) - return input_ids + 100, req_embeds, {"marker_seen": info_dict.get("marker")} - - -class MultimodalPreprocessModel(torch.nn.Module): - """Tracks fallback raw-token slices when multimodal preprocess runs from embeds.""" - - has_preprocess = True - requires_raw_input_tokens = False - - def __init__(self, hidden_size: int = 4): - super().__init__() - self.hidden_size = hidden_size - self.observed_input_ids = [] - self.observed_input_embeds = [] - - def embed_input_ids(self, input_ids, multimodal_embeddings=None, is_multimodal=None): - del multimodal_embeddings, is_multimodal - return input_ids.to(dtype=torch.float32).unsqueeze(-1).repeat(1, self.hidden_size) - - def preprocess(self, input_ids, input_embeds, **info_dict): - self.observed_input_ids.append(input_ids.clone()) - self.observed_input_embeds.append(input_embeds.clone() if isinstance(input_embeds, torch.Tensor) else None) - req_embeds = input_ids.to(dtype=torch.float32).unsqueeze(-1).repeat(1, self.hidden_size) - return input_ids, req_embeds, {"marker_seen": info_dict.get("marker")} - - class DummyTalkerMTP(torch.nn.Module): """A fake talker_mtp module for deterministic CPU testing.""" @@ -188,46 +148,6 @@ class _DummyVllmConfig: return runner -def _make_preprocess_runner(model, hidden_size=4): - runner = object.__new__(OmniGPUModelRunner) - runner.model = model - runner.model_config = SimpleNamespace(is_encoder_decoder=False) - runner.supports_mm_inputs = False - runner.enable_prompt_embeds = False - runner.uses_mrope = False - runner.uses_xdrope_dim = 0 - runner.positions = DummyBuffer(torch.arange(8, dtype=torch.int64)) - runner.input_ids = DummyBuffer(torch.tensor([1, 2, 3, 4], dtype=torch.int32)) - runner.inputs_embeds = DummyBuffer(torch.full((4, hidden_size), -1.0, dtype=torch.float32)) - runner.input_batch = SimpleNamespace( - req_ids=["r1"], - 
num_computed_tokens_cpu=np.array([0], dtype=np.int32), - ) - runner.requests = {"r1": SimpleNamespace(prompt_token_ids=[], mm_features=[])} - runner.model_intermediate_buffer = {"r1": {"marker": "r1"}} - runner.query_start_loc = SimpleNamespace(cpu=torch.tensor([0], dtype=torch.int32)) - runner.dtype = torch.float32 - runner.device = torch.device("cpu") - runner.vllm_config = SimpleNamespace(model_config=SimpleNamespace(async_chunk=False)) - runner._init_model_kwargs = lambda: {} - return runner - - -def _make_mm_preprocess_runner(model, hidden_size=4): - runner = _make_preprocess_runner(model, hidden_size=hidden_size) - runner.supports_mm_inputs = True - runner.encoder_cache = None - runner._execute_mm_encoder = lambda scheduler_output: None - runner._gather_mm_embeddings = lambda scheduler_output: (None, None) - runner._prepare_mm_inputs = lambda num_input_tokens: ( - None, - runner.inputs_embeds.gpu[:num_input_tokens], - ) - runner._extract_mm_kwargs = lambda scheduler_output: {} - runner.maybe_get_ec_connector_output = _noop_forward_context - return runner - - class StopAfterBookkeepingError(Exception): pass @@ -476,75 +396,3 @@ def fake_bookkeeping( "r2": {"token": 2, "pending": 22}, } assert captured["buffer_before_bookkeeping"] == {"r1": {"token": 1}, "r2": {"token": 2}} - - -def test_preprocess_passes_none_input_embeds_for_raw_token_models(monkeypatch): - import vllm_omni.worker.gpu_model_runner as mod - - monkeypatch.setattr(mod, "get_pp_group", lambda: SimpleNamespace(is_first_rank=True)) - - runner = _make_preprocess_runner(RawTokenPreprocessModel(hidden_size=4), hidden_size=4) - scheduler_output = SimpleNamespace( - total_num_scheduled_tokens=2, - num_scheduled_tokens={"r1": 2}, - scheduled_encoder_inputs=None, - ) - - input_ids, inputs_embeds, *_ = OmniGPUModelRunner._preprocess( - runner, - scheduler_output, - num_input_tokens=2, - ) - - assert runner.model.observed_input_embeds == [None] - assert torch.equal(input_ids, torch.tensor([101, 102], 
dtype=torch.int32)) - assert inputs_embeds.data_ptr() == runner.inputs_embeds.gpu[:2].data_ptr() - assert torch.equal( - inputs_embeds, - torch.tensor( - [ - [1.0, 1.0, 1.0, 1.0], - [2.0, 2.0, 2.0, 2.0], - ], - dtype=torch.float32, - ), - ) - assert runner.model_intermediate_buffer["r1"]["marker_seen"] == "r1" - - -def test_preprocess_uses_buffered_input_ids_when_multimodal_path_returns_none(monkeypatch): - import vllm_omni.worker.gpu_model_runner as mod - - monkeypatch.setattr(mod, "get_pp_group", lambda: SimpleNamespace(is_first_rank=True)) - - runner = _make_mm_preprocess_runner(MultimodalPreprocessModel(hidden_size=4), hidden_size=4) - scheduler_output = SimpleNamespace( - total_num_scheduled_tokens=2, - num_scheduled_tokens={"r1": 2}, - scheduled_encoder_inputs=None, - ) - - input_ids, inputs_embeds, *_ = OmniGPUModelRunner._preprocess( - runner, - scheduler_output, - num_input_tokens=2, - ) - - assert input_ids is None - assert len(runner.model.observed_input_ids) == 1 - assert torch.equal(runner.model.observed_input_ids[0], torch.tensor([1, 2], dtype=torch.int32)) - assert torch.equal( - runner.model.observed_input_embeds[0], - runner.inputs_embeds.gpu[:2], - ) - assert torch.equal( - inputs_embeds, - torch.tensor( - [ - [1.0, 1.0, 1.0, 1.0], - [2.0, 2.0, 2.0, 2.0], - ], - dtype=torch.float32, - ), - ) - assert runner.model_intermediate_buffer["r1"]["marker_seen"] == "r1" From 0a7485caa90bcf01c31e6127d3671a19ea39976a Mon Sep 17 00:00:00 2001 From: "ramos.ma" Date: Thu, 12 Mar 2026 17:06:44 +0800 Subject: [PATCH 06/14] Reduce shared runner scope for FunAudioChat Signed-off-by: mayufeng --- vllm_omni/worker/gpu_ar_model_runner.py | 17 +++++++++- .../worker/gpu_generation_model_runner.py | 14 ++++---- vllm_omni/worker/gpu_model_runner.py | 34 +++++++------------ 3 files changed, 36 insertions(+), 29 deletions(-) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 484ca7f90b9..4ecbfdc7082 100644 --- 
a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -417,6 +417,21 @@ def _postprocess_sampled_token_ids( ) return sampled_token_ids if corrected_sampled_token_ids is None else corrected_sampled_token_ids + @staticmethod + def _normalize_overlay_value(value: Any) -> Any: + if isinstance(value, torch.Tensor): + return value.detach().to("cpu").contiguous() + if isinstance(value, list): + return [GPUARModelRunner._normalize_overlay_value(item) for item in value] + if isinstance(value, tuple): + return tuple(GPUARModelRunner._normalize_overlay_value(item) for item in value) + return value + + def _normalize_overlay_update(self, upd: dict[str, Any] | None) -> dict[str, Any]: + if not isinstance(upd, dict) or not upd: + return {} + return {key: self._normalize_overlay_value(value) for key, value in upd.items()} + def _collect_pending_postprocess_updates( self, hidden_states: torch.Tensor, @@ -433,7 +448,7 @@ def _collect_pending_postprocess_updates( s, e = start_offset, start_offset + sched_tokens hidden_states_slice = hidden_states[s:e] update_dict = self.model.postprocess(hidden_states_slice, **req_infos) - normalized = self._normalize_intermediate_update(update_dict) + normalized = self._normalize_overlay_update(update_dict) if normalized: updates[req_id] = normalized except Exception as e: diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index 7c9de942fd1..05785d7aa6c 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -265,9 +265,9 @@ def execute_model( # Run the model. # Use persistent buffers for CUDA graphs. - # When spec decode is enabled, delay connector finalization until - # after draft model runs in sample_tokens. - defer_finalize = self.speculative_config is not None + # When spec decode is enabled, delay clearing connector metadata + # until after draft model runs in sample_tokens. 
+ clear_kv_metadata = self.speculative_config is None with ( set_forward_context( attn_metadata, @@ -280,7 +280,9 @@ def execute_model( slot_mapping=slot_mappings, # OMNI: required for KV cache operations ), record_function_or_nullcontext("Forward"), - self.maybe_get_kv_connector_output(scheduler_output, defer_finalize=defer_finalize) as kv_connector_output, + self.maybe_get_kv_connector_output( + scheduler_output, clear_metadata=clear_kv_metadata + ) as kv_connector_output, ): outputs = self._run_generation_model( input_ids=input_ids, @@ -349,9 +351,9 @@ def sample_tokens( ) = self.execute_model_state self.execute_model_state = None - # Finalize KV connector after draft model runs (if spec decode). + # Clear KV connector metadata after draft model runs (if spec decode). if self.speculative_config is not None: - self.finalize_kv_connector() + self.clear_kv_connector_metadata() pooler_output: list[object] = [] if isinstance(multimodal_outputs, torch.Tensor): diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index c0e4d82862f..3238f1108c7 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1353,41 +1353,31 @@ def _model_forward( return model_output def _update_intermediate_buffer(self, req_id: str, upd: dict) -> None: - normalized = self._normalize_intermediate_update(upd) - if not normalized: + if not isinstance(upd, dict) or not upd: return req_state = self.requests.get(req_id) if req_state is None: return - # Preserve upstream GPU-resident buffer behavior for models that - # explicitly opt in, while keeping normalized CPU values elsewhere. 
+ # Check if the model declares keys that should stay on GPU gpu_keys: set[str] = set() if hasattr(self, "model") and hasattr(self.model, "gpu_resident_buffer_keys"): gpu_keys = self.model.gpu_resident_buffer_keys existing = self.model_intermediate_buffer.setdefault(req_id, {}) for key, value in upd.items(): - if key in gpu_keys and isinstance(value, torch.Tensor): - existing[key] = value.detach().clone() + if isinstance(value, torch.Tensor): + if key in gpu_keys: + existing[key] = value.detach().clone() + else: + existing[key] = value.detach().to("cpu").contiguous() + elif isinstance(value, list): + existing[key] = [ + (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in value + ] else: - existing[key] = normalized[key] + existing[key] = value # Backward compatible: mirror to old setattr location setattr(req_state, "additional_information_cpu", existing) def _merge_additional_information_update(self, req_id, upd): logger.warning_once("_merge_additional_information_update is deprecated, use _update_intermediate_buffer") return self._update_intermediate_buffer(req_id, upd) - - @staticmethod - def _normalize_intermediate_update_value(value: Any) -> Any: - if isinstance(value, torch.Tensor): - return value.detach().to("cpu").contiguous() - if isinstance(value, list): - return [OmniGPUModelRunner._normalize_intermediate_update_value(item) for item in value] - if isinstance(value, tuple): - return tuple(OmniGPUModelRunner._normalize_intermediate_update_value(item) for item in value) - return value - - def _normalize_intermediate_update(self, upd: dict[str, Any] | None) -> dict[str, Any]: - if not isinstance(upd, dict) or not upd: - return {} - return {k: self._normalize_intermediate_update_value(v) for k, v in upd.items()} From 567e6d567914778e322e0aee2b569a9bb7562525 Mon Sep 17 00:00:00 2001 From: "ramos.ma" Date: Thu, 12 Mar 2026 17:50:11 +0800 Subject: [PATCH 07/14] Trim shared runner scope for FunAudioChat Signed-off-by: 
ramos.ma Signed-off-by: mayufeng --- tests/worker/test_omni_gpu_model_runner.py | 146 ------------------ .../funaudiochat/funaudiochat_code2wav.py | 4 +- vllm_omni/worker/gpu_ar_model_runner.py | 20 +-- .../worker/gpu_generation_model_runner.py | 12 +- vllm_omni/worker/gpu_model_runner.py | 47 +++--- 5 files changed, 31 insertions(+), 198 deletions(-) diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py index 60d6610ca46..b2d61931558 100644 --- a/tests/worker/test_omni_gpu_model_runner.py +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -1,12 +1,9 @@ from contextlib import contextmanager from types import SimpleNamespace -import numpy as np import pytest import torch -from vllm.v1.outputs import SamplerOutput -from vllm_omni.worker.gpu_ar_model_runner import GPUARModelRunner from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner pytestmark = [pytest.mark.core_model, pytest.mark.cpu] @@ -41,35 +38,6 @@ def __init__(self): # No real forward needed for these tests. 
-class ReplaceSampledTokensModel(torch.nn.Module): - """Returns a replacement sampled-token tensor from the post-sample hook.""" - - def __init__(self): - super().__init__() - self.observed_sampled_token_ids = None - - def postprocess_sampled_tokens(self, sampled_token_ids, req_ids, req_id_to_index, model_intermediate_buffer): - assert req_ids == ["r1", "r2"] - assert req_id_to_index == {"r1": 0, "r2": 1} - assert model_intermediate_buffer == {"r1": {"token": 1}, "r2": {"token": 2}} - self.observed_sampled_token_ids = sampled_token_ids.clone() - return sampled_token_ids + 10 - - -class OverlaySampledTokensModel(torch.nn.Module): - """Validates that post-sample hooks receive overlaid pending updates.""" - - def __init__(self): - super().__init__() - self.observed_buffer = None - self.pooler_output_buffer_keys = ("audio_token_ids",) - - def postprocess_sampled_tokens(self, sampled_token_ids, req_ids, req_id_to_index, model_intermediate_buffer): - del sampled_token_ids, req_ids, req_id_to_index - self.observed_buffer = model_intermediate_buffer - return None - - class DummyTalkerMTP(torch.nn.Module): """A fake talker_mtp module for deterministic CPU testing.""" @@ -148,58 +116,6 @@ class _DummyVllmConfig: return runner -class StopAfterBookkeepingError(Exception): - pass - - -def _make_sample_tokens_runner(model): - runner = object.__new__(GPUARModelRunner) - runner.model = model - runner.speculative_config = None - runner.use_async_scheduling = False - runner.input_batch = SimpleNamespace( - req_ids=["r1", "r2"], - req_id_to_index={"r1": 0, "r2": 1}, - sampling_metadata=SimpleNamespace(no_penalties=True), - prev_sampled_token_ids=None, - num_tokens_no_spec=np.array([1, 1], dtype=np.int32), - token_ids_cpu=np.array([[1, 0, 0, 0], [2, 0, 0, 0]], dtype=np.int32), - vocab_size=32000, - ) - runner.model_intermediate_buffer = {"r1": {"token": 1}, "r2": {"token": 2}} - runner.requests = { - "r1": SimpleNamespace(output_token_ids=[1]), - "r2": 
SimpleNamespace(output_token_ids=[2]), - } - runner.execute_model_state = ( - SimpleNamespace(total_num_scheduled_tokens=2, num_scheduled_tokens={"r1": 1, "r2": 1}), - None, - None, - None, - torch.zeros((2, 4), dtype=torch.float32), - torch.zeros((2, 4), dtype=torch.float32), - None, - None, - None, - None, - None, - ) - runner._sample = lambda logits, spec_decode_metadata: SamplerOutput( - sampled_token_ids=torch.tensor([[1], [2]], dtype=torch.int32), - logprobs_tensors=None, - ) - runner.max_model_len = 4 - runner.query_start_loc = SimpleNamespace(cpu=torch.tensor([0, 1], dtype=torch.int32)) - runner._omni_num_scheduled_tokens_np = np.array([1, 1], dtype=np.int32) - runner.vllm_config = SimpleNamespace(model_config=SimpleNamespace(engine_output_type="omni")) - runner.model_config = SimpleNamespace(enable_return_routed_experts=False) - runner.supports_mm_inputs = False - runner.kv_connector_output = None - runner.eplb_step = lambda: None - runner.finalize_kv_connector = lambda: None - return runner - - def test_talker_mtp_forward_cpu_updates_inputs_and_info(monkeypatch): # Patch the module-level `set_forward_context` symbol used inside # OmniGPUModelRunner._talker_mtp_forward. @@ -334,65 +250,3 @@ def test_maybe_attach_mimo_audio_req_infos_no_req_state_returns_input(): # When no req_state, helper should be a no-op. 
assert result is req_infos - - -def test_sample_tokens_applies_postprocessed_tokens_before_bookkeeping(): - runner = _make_sample_tokens_runner(ReplaceSampledTokensModel()) - captured = {} - - def fake_bookkeeping( - self, - scheduler_output, - sampler_output, - logits, - hidden_states, - num_scheduled_tokens, - spec_decode_metadata, - ): - captured["sampled_token_ids"] = sampler_output.sampled_token_ids.clone() - raise StopAfterBookkeepingError - - runner._bookkeeping_sync = fake_bookkeeping.__get__(runner, type(runner)) - - with pytest.raises(StopAfterBookkeepingError): - GPUARModelRunner.sample_tokens(runner, grammar_output=None) - - assert torch.equal(runner.model.observed_sampled_token_ids, torch.tensor([[1], [2]], dtype=torch.int32)) - assert torch.equal(captured["sampled_token_ids"], torch.tensor([[11], [12]], dtype=torch.int32)) - - -def test_sample_tokens_passes_pending_updates_to_postprocess_without_committing_before_bookkeeping(): - runner = _make_sample_tokens_runner(OverlaySampledTokensModel()) - - def fake_collect(*args, **kwargs): - del args, kwargs - return {"r1": {"pending": 11}, "r2": {"pending": 22}} - - captured = {} - - def fake_bookkeeping( - self, - scheduler_output, - sampler_output, - logits, - hidden_states, - num_scheduled_tokens, - spec_decode_metadata, - ): - del scheduler_output, sampler_output, logits, hidden_states, num_scheduled_tokens, spec_decode_metadata - captured["buffer_before_bookkeeping"] = { - req_id: dict(info) for req_id, info in self.model_intermediate_buffer.items() - } - raise StopAfterBookkeepingError - - runner._collect_additional_information_updates = fake_collect - runner._bookkeeping_sync = fake_bookkeeping.__get__(runner, type(runner)) - - with pytest.raises(StopAfterBookkeepingError): - GPUARModelRunner.sample_tokens(runner, grammar_output=None) - - assert runner.model.observed_buffer == { - "r1": {"token": 1, "pending": 11}, - "r2": {"token": 2, "pending": 22}, - } - assert captured["buffer_before_bookkeeping"] 
== {"r1": {"token": 1}, "r2": {"token": 2}} diff --git a/vllm_omni/model_executor/models/funaudiochat/funaudiochat_code2wav.py b/vllm_omni/model_executor/models/funaudiochat/funaudiochat_code2wav.py index 06bff4a380d..b373babaed6 100644 --- a/vllm_omni/model_executor/models/funaudiochat/funaudiochat_code2wav.py +++ b/vllm_omni/model_executor/models/funaudiochat/funaudiochat_code2wav.py @@ -145,7 +145,9 @@ def _build_decode_tokens( raw_id_batches = [torch.empty((0,), dtype=torch.long)] token_batches = [ - raw_ids.reshape(1, -1).to(dtype=torch.long, device=self.vllm_config.device_config.device).clamp_( + raw_ids.reshape(1, -1) + .to(dtype=torch.long, device=self.vllm_config.device_config.device) + .clamp_( min=0, max=self._max_codec_token_id, ) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 4ecbfdc7082..73daba31527 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -462,20 +462,6 @@ def _collect_pending_postprocess_updates( traceback.print_exc() return updates - def _collect_additional_information_updates( - self, - hidden_states: torch.Tensor, - multimodal_outputs: object, - num_scheduled_tokens_np: np.ndarray, - scheduler_output: SchedulerOutput | None = None, - ) -> dict[str, dict[str, Any]]: - del scheduler_output - return self._collect_pending_postprocess_updates( - hidden_states, - multimodal_outputs, - num_scheduled_tokens_np, - ) - def _build_overlay_intermediate_buffer_local( self, pending_updates: dict[str, dict[str, Any]], @@ -565,8 +551,10 @@ def sample_tokens( dtype=np.int32, ) - pending_intermediate_updates = self._collect_additional_information_updates( - hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output + pending_intermediate_updates = self._collect_pending_postprocess_updates( + hidden_states, + multimodal_outputs, + num_scheduled_tokens_np, ) overlay_intermediate_buffer = None postprocess_hook = getattr(self.model, 
"postprocess_sampled_tokens", None) diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index 05785d7aa6c..49fc7d3f018 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -265,9 +265,9 @@ def execute_model( # Run the model. # Use persistent buffers for CUDA graphs. - # When spec decode is enabled, delay clearing connector metadata - # until after draft model runs in sample_tokens. - clear_kv_metadata = self.speculative_config is None + # When spec decode is enabled, delay connector finalization until + # after draft model runs in sample_tokens. + defer_finalize = self.speculative_config is not None with ( set_forward_context( attn_metadata, @@ -280,9 +280,7 @@ def execute_model( slot_mapping=slot_mappings, # OMNI: required for KV cache operations ), record_function_or_nullcontext("Forward"), - self.maybe_get_kv_connector_output( - scheduler_output, clear_metadata=clear_kv_metadata - ) as kv_connector_output, + self.maybe_get_kv_connector_output(scheduler_output, defer_finalize=defer_finalize) as kv_connector_output, ): outputs = self._run_generation_model( input_ids=input_ids, @@ -353,7 +351,7 @@ def sample_tokens( # Clear KV connector metadata after draft model runs (if spec decode). 
if self.speculative_config is not None: - self.clear_kv_connector_metadata() + self.finalize_kv_connector() pooler_output: list[object] = [] if isinstance(multimodal_outputs, torch.Tensor): diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 3238f1108c7..e76b2076a3c 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1158,15 +1158,8 @@ def _preprocess( model_kwargs = self._init_model_kwargs() input_ids = self.input_ids.gpu[:num_input_tokens] elif getattr(self.model, "has_preprocess", False): - # Raw-token preprocess stages should see input_ids first, then - # materialize embeddings back into the pre-allocated buffer before - # the final forward. input_ids = self.input_ids.gpu[:num_input_tokens] - inputs_embeds = ( - None - if getattr(self.model, "requires_raw_input_tokens", False) - else self.inputs_embeds.gpu[:num_input_tokens] - ) + inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] model_kwargs = self._init_model_kwargs() else: # For text-only models, we use token ids as input. @@ -1251,10 +1244,11 @@ def _preprocess( input_embeds=embed_slice, **req_infos, ) - if req_embeds is None: - raise RuntimeError( - "Model preprocess must return req_embeds when has_preprocess=True, " - f"but got None for request {req_id}." 
+ if inputs_embeds is None: + inputs_embeds = torch.empty( + (preprocess_input_ids.shape[0], req_embeds.shape[-1]), + device=req_embeds.device, + dtype=req_embeds.dtype, ) if hasattr(self.model, "talker_mtp") and span_len == 1: @@ -1269,17 +1263,14 @@ def _preprocess( # TODO(Peiqi): the merge stage could move out from the critical path self._merge_additional_information_update(req_id, update_dict) - if inputs_embeds is None: - inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] - if inputs_embeds is not None: - seg_len = min(span_len, req_embeds.shape[0]) - inputs_embeds[s : s + seg_len] = req_embeds[:seg_len] + seg_len = min(span_len, req_embeds.shape[0]) + inputs_embeds[s : s + seg_len] = req_embeds[:seg_len] if ( input_ids is not None and isinstance(req_input_ids, torch.Tensor) - and req_input_ids.numel() == span_len + and req_input_ids.numel() == seg_len ): - input_ids[s : s + span_len] = req_input_ids + input_ids[s : s + seg_len] = req_input_ids # run talker mtp decode if hasattr(self.model, "talker_mtp"): @@ -1363,18 +1354,18 @@ def _update_intermediate_buffer(self, req_id: str, upd: dict) -> None: if hasattr(self, "model") and hasattr(self.model, "gpu_resident_buffer_keys"): gpu_keys = self.model.gpu_resident_buffer_keys existing = self.model_intermediate_buffer.setdefault(req_id, {}) - for key, value in upd.items(): - if isinstance(value, torch.Tensor): - if key in gpu_keys: - existing[key] = value.detach().clone() + for k, v in upd.items(): + if isinstance(v, torch.Tensor): + if k in gpu_keys: + existing[k] = v.detach().clone() else: - existing[key] = value.detach().to("cpu").contiguous() - elif isinstance(value, list): - existing[key] = [ - (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in value + existing[k] = v.detach().to("cpu").contiguous() + elif isinstance(v, list): + existing[k] = [ + (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in v ] else: - 
existing[key] = value + existing[k] = v # Backward compatible: mirror to old setattr location setattr(req_state, "additional_information_cpu", existing) From f0fb323820139e4b96808a763360a2a769621794 Mon Sep 17 00:00:00 2001 From: "ramos.ma" Date: Fri, 13 Mar 2026 14:04:21 +0800 Subject: [PATCH 08/14] Address latest FunAudioChat review fixes Signed-off-by: mayufeng --- docs/models/supported_models.md | 2 ++ .../entrypoints/test_funaudiochat_contrib.py | 27 +++++++++++++++++++ ...unaudiochat_s2s.yaml => funaudiochat.yaml} | 5 ++++ vllm_omni/worker/gpu_ar_model_runner.py | 12 +++++---- .../worker/gpu_generation_model_runner.py | 12 +++++---- 5 files changed, 48 insertions(+), 10 deletions(-) rename vllm_omni/model_executor/stage_configs/{funaudiochat_s2s.yaml => funaudiochat.yaml} (94%) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 3fe0021f936..9d2d72c23a9 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -53,6 +53,8 @@ th { |`Flux2Pipeline` | FLUX.2-dev | `black-forest-labs/FLUX.2-dev` | |`FishSpeechSlowARForConditionalGeneration` | Fish Speech S2 Pro | `fishaudio/s2-pro` | +For FunAudioChat S2S with local checkpoints, override both stage `engine_args.model` paths in the stage config instead of using the default HF repo IDs. 
+ ## List of Supported Models for NPU diff --git a/tests/entrypoints/test_funaudiochat_contrib.py b/tests/entrypoints/test_funaudiochat_contrib.py index 7ca70f0eebf..b1d746a1bb9 100644 --- a/tests/entrypoints/test_funaudiochat_contrib.py +++ b/tests/entrypoints/test_funaudiochat_contrib.py @@ -5,8 +5,10 @@ from types import SimpleNamespace import pytest +import yaml from vllm_omni.engine.arg_utils import _resolve_bundled_hf_config_path +from vllm_omni.entrypoints import utils as entrypoint_utils from vllm_omni.entrypoints.omni import OmniBase pytestmark = [pytest.mark.core_model, pytest.mark.cpu] @@ -36,3 +38,28 @@ def test_get_stage_model_falls_back_when_stage_override_missing(): stage = SimpleNamespace(engine_args=SimpleNamespace()) assert OmniBase._get_stage_model(stage, "fallback-model") == "fallback-model" + + +def test_resolve_model_config_path_detects_funaudiochat_default_yaml(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + entrypoint_utils, + "get_config", + lambda model, trust_remote_code=True: SimpleNamespace(model_type="funaudiochat"), + ) + + resolved = entrypoint_utils.resolve_model_config_path("dummy-funaudiochat-model") + + assert resolved is not None + assert resolved.endswith("vllm_omni/model_executor/stage_configs/funaudiochat.yaml") + + +def test_funaudiochat_default_stage_config_limits_audio_profile_and_keeps_audio_towers(): + config_path = ( + Path(__file__).resolve().parents[2] / "vllm_omni" / "model_executor" / "stage_configs" / "funaudiochat.yaml" + ) + config = yaml.safe_load(config_path.read_text()) + stage0_engine_args = config["stage_args"][0]["engine_args"] + + assert "language_model_only" not in stage0_engine_args + assert stage0_engine_args["hf_overrides"]["audio_config"]["max_source_positions"] == 100 + assert stage0_engine_args["limit_mm_per_prompt"]["audio"] == 1 diff --git a/vllm_omni/model_executor/stage_configs/funaudiochat_s2s.yaml b/vllm_omni/model_executor/stage_configs/funaudiochat.yaml similarity index 94% rename 
from vllm_omni/model_executor/stage_configs/funaudiochat_s2s.yaml rename to vllm_omni/model_executor/stage_configs/funaudiochat.yaml index 07a8e84eea6..efe067a0051 100644 --- a/vllm_omni/model_executor/stage_configs/funaudiochat_s2s.yaml +++ b/vllm_omni/model_executor/stage_configs/funaudiochat.yaml @@ -8,6 +8,11 @@ stage_args: model: FunAudioLLM/Fun-Audio-Chat-8B model_stage: s2s model_arch: FunAudioChatForConditionalGeneration + hf_overrides: + audio_config: + max_source_positions: 100 + limit_mm_per_prompt: + audio: 1 worker_type: ar scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler trust_remote_code: true diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 73daba31527..edb75f992f5 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -279,9 +279,9 @@ def execute_model( # Run the model. # Use persistent buffers for CUDA graphs. - # When spec decode is enabled, delay connector finalization until - # after draft model runs in sample_tokens. - defer_finalize = self.speculative_config is not None + # When spec decode is enabled, delay clearing connector metadata + # until after draft model runs in sample_tokens. + clear_kv_metadata = self.speculative_config is None with ( set_forward_context( attn_metadata, @@ -294,7 +294,9 @@ def execute_model( slot_mapping=slot_mappings, # OMNI: required for KV cache operations ), record_function_or_nullcontext("gpu_model_runner: forward"), - self.maybe_get_kv_connector_output(scheduler_output, defer_finalize=defer_finalize) as kv_connector_output, + self.maybe_get_kv_connector_output( + scheduler_output, clear_metadata=clear_kv_metadata + ) as kv_connector_output, ): model_output = self._model_forward( input_ids=input_ids, @@ -657,7 +659,7 @@ def propose_draft_token_ids(sampled_token_ids): # This was deferred from target model forward to allow draft model # to also save its KV cache. 
if self.speculative_config is not None: - self.finalize_kv_connector() + self.clear_kv_connector_metadata() with record_function_or_nullcontext("gpu_model_runner: eplb"): self.eplb_step() diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index 49fc7d3f018..05785d7aa6c 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -265,9 +265,9 @@ def execute_model( # Run the model. # Use persistent buffers for CUDA graphs. - # When spec decode is enabled, delay connector finalization until - # after draft model runs in sample_tokens. - defer_finalize = self.speculative_config is not None + # When spec decode is enabled, delay clearing connector metadata + # until after draft model runs in sample_tokens. + clear_kv_metadata = self.speculative_config is None with ( set_forward_context( attn_metadata, @@ -280,7 +280,9 @@ def execute_model( slot_mapping=slot_mappings, # OMNI: required for KV cache operations ), record_function_or_nullcontext("Forward"), - self.maybe_get_kv_connector_output(scheduler_output, defer_finalize=defer_finalize) as kv_connector_output, + self.maybe_get_kv_connector_output( + scheduler_output, clear_metadata=clear_kv_metadata + ) as kv_connector_output, ): outputs = self._run_generation_model( input_ids=input_ids, @@ -351,7 +353,7 @@ def sample_tokens( # Clear KV connector metadata after draft model runs (if spec decode). 
if self.speculative_config is not None: - self.finalize_kv_connector() + self.clear_kv_connector_metadata() pooler_output: list[object] = [] if isinstance(multimodal_outputs, torch.Tensor): From 600f45bd90832e81fa8999a348f14301c4769a35 Mon Sep 17 00:00:00 2001 From: "ramos.ma" Date: Fri, 13 Mar 2026 17:05:38 +0800 Subject: [PATCH 09/14] Move FunAudioChat local checkpoint note to stage config Signed-off-by: ramos.ma --- docs/models/supported_models.md | 3 --- vllm_omni/model_executor/stage_configs/funaudiochat.yaml | 2 ++ 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 9d2d72c23a9..33037fa515e 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -53,9 +53,6 @@ th { |`Flux2Pipeline` | FLUX.2-dev | `black-forest-labs/FLUX.2-dev` | |`FishSpeechSlowARForConditionalGeneration` | Fish Speech S2 Pro | `fishaudio/s2-pro` | -For FunAudioChat S2S with local checkpoints, override both stage `engine_args.model` paths in the stage config instead of using the default HF repo IDs. - - ## List of Supported Models for NPU