From ee5e2cc856e3e51b2610693c19df5549375d799f Mon Sep 17 00:00:00 2001 From: "leo.yang" Date: Mon, 1 Jun 2026 12:18:57 +0200 Subject: [PATCH 1/2] [Cleanup] Remove dead/stale modules with no repo references Remove six files confirmed unreferenced across all Python imports, __init__.py exports, dynamic import strings, and YAML configs: - vllm_omni/assets/video.py - vllm_omni/benchmarks/data_modules/daily_omni_text_audio.py - vllm_omni/distributed/kv_transfer/monkey_patch.py - vllm_omni/model_executor/models/qwen3_omni/qwen3_moe.py - vllm_omni/model_executor/models/moss_tts_nano/configuration_moss_tts_nano.py - vllm_omni/model_executor/stage_input_processors/omnivoice.py Part of codebase cleanup audit #4009. Signed-off-by: leo.yang --- vllm_omni/assets/video.py | 16 -- .../data_modules/daily_omni_text_audio.py | 255 ------------------ .../distributed/kv_transfer/monkey_patch.py | 193 ------------- .../configuration_moss_tts_nano.py | 60 ----- .../models/qwen3_omni/qwen3_moe.py | 173 ------------ .../stage_input_processors/omnivoice.py | 34 --- 6 files changed, 731 deletions(-) delete mode 100644 vllm_omni/assets/video.py delete mode 100644 vllm_omni/benchmarks/data_modules/daily_omni_text_audio.py delete mode 100644 vllm_omni/distributed/kv_transfer/monkey_patch.py delete mode 100644 vllm_omni/model_executor/models/moss_tts_nano/configuration_moss_tts_nano.py delete mode 100644 vllm_omni/model_executor/models/qwen3_omni/qwen3_moe.py delete mode 100644 vllm_omni/model_executor/stage_input_processors/omnivoice.py diff --git a/vllm_omni/assets/video.py b/vllm_omni/assets/video.py deleted file mode 100644 index 6a5f3204a91..00000000000 --- a/vllm_omni/assets/video.py +++ /dev/null @@ -1,16 +0,0 @@ -import numpy as np -from vllm.assets.video import VideoAsset -from vllm.multimodal.media.audio import load_audio - - -def extract_video_audio(path: str = None, sampling_rate: int = 16000) -> np.ndarray: - """This function extracts the audio from a video file path and returns the audio as a numpy array. - Args: - path: The path to the video file. - Returns: - The audio as a numpy array. - """ - if not path: - path = VideoAsset(name="baby_reading").video_path - audio_signal, sr = load_audio(path, sr=sampling_rate) - return audio_signal diff --git a/vllm_omni/benchmarks/data_modules/daily_omni_text_audio.py b/vllm_omni/benchmarks/data_modules/daily_omni_text_audio.py deleted file mode 100644 index 69fbe026bd8..00000000000 --- a/vllm_omni/benchmarks/data_modules/daily_omni_text_audio.py +++ /dev/null @@ -1,255 +0,0 @@ -"""Daily-Omni: optional consistency check between text stream and generated speech. - -The benchmark MCQ accuracy uses ``generated_text`` only. When the omni server also -streams ``modality=audio`` (TTS), this module can transcribe the concatenated WAV -with Whisper and compare the inferred option letter to the one parsed from text. - -Requires ``openai-whisper`` (``pip install openai-whisper``). Enable via env -``DAILY_OMNI_TEXT_AUDIO_CONSISTENCY=1`` or CLI ``--daily-omni-text-audio-consistency``. - -Whisper model name defaults to ``tiny`` (override with ``DAILY_OMNI_WHISPER_MODEL``). -""" - -from __future__ import annotations - -import logging -import os -import re -import threading -from typing import Any - -from vllm_omni.benchmarks.data_modules.daily_omni_dataset import DailyOmniSampleRequest -from vllm_omni.benchmarks.data_modules.daily_omni_eval import extract_predicted_choice - -logger = logging.getLogger(__name__) - -_whisper_model = None -_whisper_model_name: str | None = None -_whisper_lock = threading.Lock() - - -def env_text_audio_check_enabled() -> bool: - return os.environ.get("DAILY_OMNI_TEXT_AUDIO_CONSISTENCY", "").lower() in ( - "1", - "true", - "yes", - ) - - -def extract_choice_from_asr_transcript(transcript: str) -> str | None: - """Parse A–D from ASR text; extends :func:`extract_predicted_choice` with spoken Chinese phrases.""" - c = extract_predicted_choice(transcript) - if c: - return c - t = transcript or "" - for pat in ( - r"(?i)选项\s*([ABCD])\b", - r"(?i)选\s*([ABCD])\b", - r"(?i)答案\s*是\s*([ABCD])\b", - r"(?i)答案\s*([ABCD])\b", - ): - m = re.search(pat, t) - if m: - return m.group(1).upper() - return None - - -def _get_whisper_model(model_name: str): - global _whisper_model, _whisper_model_name - with _whisper_lock: - if _whisper_model is None or _whisper_model_name != model_name: - import whisper - - logger.warning( - "Loading Whisper model %r for Daily-Omni text/audio consistency (one-time)...", - model_name, - ) - _whisper_model = whisper.load_model(model_name) - _whisper_model_name = model_name - return _whisper_model - - -def transcribe_wav_bytes( - wav_bytes: bytes, - *, - language: str | None = None, - model_name: str | None = None, -) -> tuple[str | None, str | None]: - """Transcribe WAV bytes. Returns ``(transcript, error)`` — one of them is set. - - Args: - wav_bytes: RIFF WAV file bytes. - language: Optional Whisper language code (e.g. ``en``, ``zh``); improves accuracy/latency. - model_name: Override model id; else ``DAILY_OMNI_WHISPER_MODEL`` or ``tiny``. - """ - if not wav_bytes: - return None, "empty_wav" - if model_name is None or not str(model_name).strip(): - model_name = os.environ.get("DAILY_OMNI_WHISPER_MODEL") or "tiny" - model_name = str(model_name).strip() or "tiny" - path: str | None = None - try: - import tempfile - - model = _get_whisper_model(model_name) - fd, path = tempfile.mkstemp(suffix=".wav") - with os.fdopen(fd, "wb") as fp: - fp.write(wav_bytes) - kwargs: dict = {} - if language: - kwargs["language"] = language - result = model.transcribe(path, **kwargs) - text = (result.get("text") or "").strip() - return (text if text else None), None - except ImportError: - return None, "openai-whisper is not installed (pip install openai-whisper)" - except Exception as e: - return None, str(e)[:500] - finally: - if path: - try: - os.unlink(path) - except OSError: - pass - - -def compute_daily_omni_text_audio_consistency_metrics( - input_requests: list[Any], - outputs: list[Any], - *, - include_per_item: bool = False, -) -> dict[str, Any] | None: - """Compare option letter from ``generated_text`` vs Whisper transcript of output audio. - - Only considers requests where ``outputs[i]`` has ``generated_audio_wav_bytes`` set - (populated by the omni benchmark when TA check is enabled). - """ - if not input_requests or len(input_requests) != len(outputs): - return None - if not all(isinstance(r, DailyOmniSampleRequest) for r in input_requests): - return None - - ta_no_wav = 0 - ta_asr_failed = 0 - ta_text_unparsed = 0 - ta_audio_unparsed = 0 - ta_consistent = 0 - ta_mismatch = 0 - ta_both_parsed = 0 - items: list[dict[str, Any]] = [] - - for req, out in zip(input_requests, outputs, strict=True): - assert isinstance(req, DailyOmniSampleRequest) - rid = req.request_id - if not getattr(out, "success", False): - if include_per_item: - items.append( - { - "request_id": rid, - "skipped": True, - "reason": "request_not_success", - } - ) - continue - - wav = getattr(out, "generated_audio_wav_bytes", None) - if not wav: - ta_no_wav += 1 - if include_per_item: - items.append( - { - "request_id": rid, - "skipped": False, - "reason": "no_output_audio", - "text_choice": extract_predicted_choice(getattr(out, "generated_text", "") or ""), - } - ) - continue - - transcript, asr_err = transcribe_wav_bytes(wav) - if asr_err: - ta_asr_failed += 1 - if include_per_item: - items.append( - { - "request_id": rid, - "asr_error": asr_err, - "text_choice": extract_predicted_choice(getattr(out, "generated_text", "") or ""), - } - ) - continue - - text_choice = extract_predicted_choice(getattr(out, "generated_text", "") or "") - audio_choice = extract_choice_from_asr_transcript(transcript or "") - - if text_choice is None: - ta_text_unparsed += 1 - if audio_choice is None: - ta_audio_unparsed += 1 - - if text_choice is not None and audio_choice is not None: - ta_both_parsed += 1 - if text_choice == audio_choice: - ta_consistent += 1 - else: - ta_mismatch += 1 - - if include_per_item: - consistent: bool | None - if text_choice is None or audio_choice is None: - consistent = None - else: - consistent = text_choice == audio_choice - items.append( - { - "request_id": rid, - "text_choice": text_choice, - "audio_choice": audio_choice, - "asr_transcript": (transcript or "")[:500], - "text_audio_consistent": consistent, - } - ) - - comparable = ta_consistent + ta_mismatch - rate = (ta_consistent / comparable) if comparable else None - - out: dict[str, Any] = { - "daily_omni_ta_enabled": True, - "daily_omni_ta_no_output_audio": ta_no_wav, - "daily_omni_ta_asr_failed": ta_asr_failed, - "daily_omni_ta_text_unparsed": ta_text_unparsed, - "daily_omni_ta_audio_unparsed": ta_audio_unparsed, - "daily_omni_ta_both_parsed": ta_both_parsed, - "daily_omni_ta_consistent": ta_consistent, - "daily_omni_ta_mismatch": ta_mismatch, - "daily_omni_ta_consistency_rate": rate, - } - if include_per_item: - out["daily_omni_ta_items"] = items - return out - - -def print_daily_omni_text_audio_summary(metrics: dict[str, Any]) -> None: - if not metrics.get("daily_omni_ta_enabled"): - return - print("{s:{c}^{n}}".format(s=" Daily-Omni text vs audio (ASR) ", n=50, c="=")) - print("{:<40} {:<10}".format("No output audio captured:", metrics.get("daily_omni_ta_no_output_audio", 0))) - print("{:<40} {:<10}".format("ASR failed:", metrics.get("daily_omni_ta_asr_failed", 0))) - print("{:<40} {:<10}".format("Both text+audio letter parsed:", metrics.get("daily_omni_ta_both_parsed", 0))) - print("{:<40} {:<10}".format("Consistent (same letter):", metrics.get("daily_omni_ta_consistent", 0))) - print("{:<40} {:<10}".format("Mismatch:", metrics.get("daily_omni_ta_mismatch", 0))) - r = metrics.get("daily_omni_ta_consistency_rate") - if r is not None: - print("{:<40} {:<10.4f}".format("Consistency rate (of both parsed):", r)) - print( - "{:<40} {:<10}".format( - "Text unparsed (among w/ audio):", - metrics.get("daily_omni_ta_text_unparsed", 0), - ) - ) - print( - "{:<40} {:<10}".format( - "Audio unparsed (among w/ audio):", - metrics.get("daily_omni_ta_audio_unparsed", 0), - ) - ) diff --git a/vllm_omni/distributed/kv_transfer/monkey_patch.py b/vllm_omni/distributed/kv_transfer/monkey_patch.py deleted file mode 100644 index 5c3b97755b8..00000000000 --- a/vllm_omni/distributed/kv_transfer/monkey_patch.py +++ /dev/null @@ -1,193 +0,0 @@ -"""Monkey-patch vLLM's MooncakeConnector to fix request-ID mismatch in PD disaggregation. - -vLLM's InputProcessor appends a random suffix to each request ID. The prefill -engine stores KV under its suffix, but the decode engine generates a different -suffix. This patch threads ``remote_request_id`` through ``kv_transfer_params`` -so the decode side references the correct KV entry. -""" - -from __future__ import annotations - -import importlib -import logging -import sys -from dataclasses import dataclass -from typing import Any - -logger = logging.getLogger(__name__) - -_patched: bool = False - - -@dataclass -class PatchedRecvReqMeta: - """Receive-request metadata carrying the prefill engine's request ID.""" - - request_id: str - remote_request_id: str - local_block_ids: list[int] - kv_transfer_params: dict[str, Any] - - -def _import_mooncake_module(): - """Import MooncakeConnector module, supporting both vLLM >=0.16 and older.""" - for mod_path in ( - "vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector", - "vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector", - ): - try: - return importlib.import_module(mod_path) - except (ImportError, ModuleNotFoundError): - continue - return None - - -def _create_patched_mooncake_connector(): - """Return a subclass of MooncakeConnector with remote_request_id support.""" - _mc_mod = _import_mooncake_module() - if _mc_mod is None: - raise ImportError("Cannot import MooncakeConnector from upstream vLLM") - _OriginalMooncakeConnector = _mc_mod.MooncakeConnector - - class PatchedMooncakeConnector(_OriginalMooncakeConnector): - """Fixes request-ID mismatch in PD disaggregation by injecting - remote_request_id on the prefill side and using it for KV lookup - on the decode side. - """ - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.remote_to_local_req: dict[str, str] = {} - logger.info("[PatchedMooncakeConnector] Initialized") - - def request_finished( - self, - request: Any, - block_ids: list[int], - ) -> tuple[bool, dict[str, Any] | None]: - result = super().request_finished(request, block_ids) - - if isinstance(result, tuple) and len(result) == 2: - delay_free, kv_params = result - else: - delay_free, kv_params = False, result - - # Normalise _reqs_need_send values - req_id = getattr(request, "request_id", None) - if req_id and hasattr(self, "_reqs_need_send"): - entry = self._reqs_need_send.get(req_id) - if isinstance(entry, tuple) and len(entry) == 2: - self._reqs_need_send[req_id] = entry[1] - - # Inject remote_request_id into kv_transfer_params - if kv_params is not None and isinstance(kv_params, dict): - kv_params["remote_request_id"] = req_id or "NOT_SET" - if hasattr(self, "side_channel_host"): - kv_params.setdefault("remote_host", self.side_channel_host) - if hasattr(self, "side_channel_port"): - kv_params.setdefault("remote_port", self.side_channel_port) - - return delay_free, kv_params - - def add_new_req( - self, - request_id: str, - local_block_ids: list[int], - kv_transfer_params: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - super().add_new_req(request_id, local_block_ids, kv_transfer_params, **kwargs) - - kv_transfer_params = kv_transfer_params or {} - load_remote_cache = kv_transfer_params.get( - "do_remote_prefill", - kv_transfer_params.get("load_remote_cache", False), - ) - - if load_remote_cache: - remote_request_id = kv_transfer_params.get("remote_request_id", request_id) - meta = PatchedRecvReqMeta( - request_id=request_id, - remote_request_id=remote_request_id, - local_block_ids=local_block_ids, - kv_transfer_params=kv_transfer_params, - ) - if not hasattr(self, "_reqs_need_recv"): - self._reqs_need_recv = {} - self._reqs_need_recv[request_id] = meta - - def group_kv_pull(self, metadata: Any | None = None) -> None: - """Use remote_request_id as ZMQ lookup key via save-patch-restore.""" - if not hasattr(self, "_reqs_need_recv") or not self._reqs_need_recv: - return - - original_recv = self._reqs_need_recv.copy() - patched_recv: dict[str, Any] = {} - - for local_id, meta in original_recv.items(): - if isinstance(meta, PatchedRecvReqMeta): - remote_id = meta.remote_request_id - self.remote_to_local_req[remote_id] = local_id - patched_meta = type(meta)( - request_id=remote_id, - remote_request_id=remote_id, - local_block_ids=meta.local_block_ids, - kv_transfer_params=meta.kv_transfer_params, - ) - patched_recv[remote_id] = patched_meta - else: - patched_recv[local_id] = meta - - self._reqs_need_recv = patched_recv - super().group_kv_pull(metadata) - - # Restore unconsumed entries to original local keys - for remote_id, local_id in list(self.remote_to_local_req.items()): - if remote_id in self._reqs_need_recv: - entry = self._reqs_need_recv.pop(remote_id) - self._reqs_need_recv[local_id] = original_recv.get(local_id, entry) - - def receive_kv(self, path: Any = None, req_blocks: Any = None) -> Any: - result = super().receive_kv(path, req_blocks) - - if self.remote_to_local_req: - completed = [ - rid - for rid, lid in self.remote_to_local_req.items() - if not hasattr(self, "_reqs_need_recv") or lid not in self._reqs_need_recv - ] - for remote_id in completed: - self.remote_to_local_req.pop(remote_id, None) - - return result - - PatchedMooncakeConnector.__qualname__ = _OriginalMooncakeConnector.__qualname__ - - return PatchedMooncakeConnector - - -def apply_mooncake_connector_patch() -> bool: - """Replace vLLM's MooncakeConnector with the patched version.""" - global _patched - if _patched: - return True - - _mc_module = _import_mooncake_module() - if _mc_module is None: - logger.warning("[monkey_patch] Cannot import MooncakeConnector — patch NOT applied.") - return False - - _OriginalClass = _mc_module.MooncakeConnector - - PatchedClass = _create_patched_mooncake_connector() - - _mc_module.MooncakeConnector = PatchedClass - # Snapshot sys.modules: hasattr() may trigger lazy submodule imports that - # mutate sys.modules during iteration. - for _, module in list(sys.modules.items()): - if hasattr(module, "MooncakeConnector") and module.MooncakeConnector is _OriginalClass: - module.MooncakeConnector = PatchedClass - - _patched = True - logger.info("[monkey_patch] MooncakeConnector patch applied") - return True diff --git a/vllm_omni/model_executor/models/moss_tts_nano/configuration_moss_tts_nano.py b/vllm_omni/model_executor/models/moss_tts_nano/configuration_moss_tts_nano.py deleted file mode 100644 index a5a9f25acd3..00000000000 --- a/vllm_omni/model_executor/models/moss_tts_nano/configuration_moss_tts_nano.py +++ /dev/null @@ -1,60 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Configuration for MOSS-TTS-Nano in vLLM-Omni single-stage pipeline.""" - -from transformers.configuration_utils import PretrainedConfig - - -class MossTTSNanoConfig(PretrainedConfig): - """Config for MOSS-TTS-Nano (OpenMOSS-Team/MOSS-TTS-Nano). - - The model is a Global-Local GPT2-based AR LM (0.1B) paired with the - MOSS-Audio-Tokenizer-Nano codec. Both components run in a single - vLLM-Omni generation stage. - - Relevant fields from config.json: - gpt2_config – backbone config (n_layer=12, n_embd=768, n_head=12, ...) - n_vq – number of RVQ codebooks (16) - audio_vocab_size – per-codebook vocabulary size (1024) - audio_tokenizer_pretrained_name_or_path – HF hub path for codec model - """ - - model_type = "moss_tts_nano" - - def __init__(self, **kwargs): - gpt2_cfg = kwargs.pop("gpt2_config", None) or {} - if hasattr(gpt2_cfg, "to_dict"): - gpt2_cfg = gpt2_cfg.to_dict() - - super().__init__(**kwargs) - - # --- GPT2 backbone parameters (exposed at top level for vLLM) --- - self.hidden_size = gpt2_cfg.get("n_embd", 768) - self.num_hidden_layers = gpt2_cfg.get("n_layer", 12) - self.num_attention_heads = gpt2_cfg.get("n_head", 12) - self.num_key_value_heads = self.num_attention_heads # no GQA in GPT2 - self.head_dim = self.hidden_size // self.num_attention_heads - self.vocab_size = gpt2_cfg.get("vocab_size", 16384) - self.max_position_embeddings = gpt2_cfg.get("n_positions", 32768) - self.intermediate_size = gpt2_cfg.get("n_inner", self.hidden_size * 4) - - # --- Audio codec parameters --- - self.n_vq: int = getattr(self, "n_vq", 16) - self.audio_vocab_size: int = getattr(self, "audio_vocab_size", 1024) - self.audio_start_token_id: int = getattr(self, "audio_start_token_id", 6) - self.audio_end_token_id: int = getattr(self, "audio_end_token_id", 7) - self.audio_user_slot_token_id: int = getattr(self, "audio_user_slot_token_id", 8) - self.audio_assistant_slot_token_id: int = getattr(self, "audio_assistant_slot_token_id", 9) - self.audio_tokenizer_pretrained_name_or_path: str = getattr( - self, - "audio_tokenizer_pretrained_name_or_path", - "OpenMOSS-Team/MOSS-Audio-Tokenizer-Nano", - ) - self.audio_tokenizer_sample_rate: int = getattr(self, "audio_tokenizer_sample_rate", 48000) - - # vLLM requires speculative_config to be absent or None - self.speculative_config = None - - def get_text_config(self, **kwargs): - """Return self so vLLM uses our top-level config.""" - return self diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_moe.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_moe.py deleted file mode 100644 index 9332363136d..00000000000 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_moe.py +++ /dev/null @@ -1,173 +0,0 @@ -from __future__ import annotations - -import torch -import torch.nn.functional as F -from torch import nn -from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.models.qwen3_moe import ( - Qwen3MoeDecoderLayer, - Qwen3MoeMLP, - Qwen3MoeModel, # as _BaseQwen3MoeModel, -) -from vllm.model_executor.models.qwen3_moe import ( - Qwen3MoeForCausalLM as _BaseQwen3MoeForCausalLM, -) -from vllm.model_executor.models.utils import ( - PPMissingLayer, - maybe_prefix, -) - -logger = init_logger(__name__) - - -# Individual expert MoE block using Qwen3MoeMLP instead of FusedMoE -class Qwen3OmniMoeSparseMoeBlock(nn.Module): - """Sparse MoE block using individual Qwen3MoeMLP experts instead of FusedMoE.""" - - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ): - super().__init__() - - config = vllm_config.model_config.hf_text_config - quant_config = vllm_config.quant_config - - self.num_experts = config.num_experts - self.top_k = config.num_experts_per_tok - self.norm_topk_prob = config.norm_topk_prob - self.hidden_size = config.hidden_size - - # Create individual expert MLPs - self.experts = nn.ModuleList( - [ - Qwen3MoeMLP( - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=False, - prefix=f"{prefix}.experts.{i}", - ) - for i in range(self.num_experts) - ] - ) - - # Router for expert selection - from vllm.model_executor.layers.linear import ReplicatedLinear - - self.gate = ReplicatedLinear( - config.hidden_size, config.num_experts, bias=False, quant_config=quant_config, prefix=f"{prefix}.gate" - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Forward pass using individual experts.""" - # Handle 3D inputs (batch, seq_len, hidden_size) by reshaping to 2D - orig_shape = hidden_states.shape - if hidden_states.dim() == 3: - batch_size, seq_len, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - elif hidden_states.dim() == 2: - num_tokens, hidden_dim = hidden_states.shape - elif hidden_states.dim() == 1: - hidden_states = hidden_states.unsqueeze(0) - num_tokens, hidden_dim = hidden_states.shape - else: - raise ValueError( - f"Qwen3OmniMoeSparseMoeBlock only supports 1D, 2D, or 3D inputs, got {hidden_states.dim()}D" - ) - - is_input_1d = len(orig_shape) == 1 - hidden_states = hidden_states.view(-1, hidden_dim) - - # Get router logits and select experts (matching transformers) - router_logits, _ = self.gate(hidden_states) - selected_experts, routing_weights = self._route_tokens(router_logits) - - # Forward through individual experts - final_hidden_states = self._forward_experts(hidden_states, selected_experts, routing_weights) - - # Reshape back to original shape - if is_input_1d: - return final_hidden_states.squeeze(0) - elif len(orig_shape) == 3: - # Reshape back to 3D (batch, seq_len, hidden_dim) - return final_hidden_states.view(orig_shape) - else: - return final_hidden_states - - def _route_tokens(self, router_logits: torch.Tensor): - """Route tokens to experts using top-k selection (matching transformers).""" - routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - if self.norm_topk_prob: - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(router_logits.dtype) - return selected_experts, routing_weights - - def _forward_experts( - self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor - ): - """Forward through individual experts (matching transformers implementation).""" - final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self.experts[expert_idx](current_state) * routing_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - - return final_hidden_states - - -class Qwen3MoeForCausalLM(_BaseQwen3MoeForCausalLM): - """Thin wrapper to swap in the patched `Qwen3MoeModel`.""" - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - # Don't call super().__init__() to avoid duplicate layer registration. - nn.Module.__init__(self) - config = vllm_config.model_config.hf_text_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.model = Qwen3MoeModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head") - ) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors - - # Set MoE hyperparameters for individual experts - self.expert_weights = [] - - self.moe_layers: list[FusedMoE] = [] - example_layer = None - for layer in self.model.layers: - if isinstance(layer, PPMissingLayer): - continue - - assert isinstance(layer, Qwen3MoeDecoderLayer) - if isinstance(layer.mlp, FusedMoE): - example_layer = layer.mlp - self.moe_layers.append(layer.mlp) - - if example_layer is None: - raise RuntimeError("No Qwen3OmniMoe layer found in the model.layers.") - - self.num_moe_layers = len(self.moe_layers) - self.num_expert_groups = 1 - self.num_shared_experts = 0 - self.num_logical_experts = example_layer.n_logical_experts - self.num_physical_experts = example_layer.n_physical_experts - self.num_local_physical_experts = example_layer.n_local_physical_experts - self.num_routed_experts = example_layer.n_routed_experts - self.num_redundant_experts = example_layer.n_redundant_experts diff --git a/vllm_omni/model_executor/stage_input_processors/omnivoice.py b/vllm_omni/model_executor/stage_input_processors/omnivoice.py deleted file mode 100644 index ce863c13568..00000000000 --- a/vllm_omni/model_executor/stage_input_processors/omnivoice.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Inter-stage processor for OmniVoice: Generator → Decoder.""" - -from typing import Any - -from vllm.inputs import TextPrompt - -from vllm_omni.inputs.data import OmniTokensPrompt - - -def tokens2audio( - source_outputs: list[Any], - _prompt: OmniTokensPrompt | TextPrompt = None, - _requires_multimodal_data: bool = True, -): - """Build stage-1 (decoder) inputs from stage-0 (generator) outputs. - - Takes the 8-codebook audio tokens from the generator and packages - them for the HiggsAudioV2 decoder. - """ - source_output = source_outputs[0] - output = source_output.outputs[0] - - multi_modal_data = output.multimodal_output - if multi_modal_data is None: - raise RuntimeError(f"Missing multimodal_output for request {source_output.request_id}") - - # Pass audio_tokens from generator to decoder - engine_input = OmniTokensPrompt( - prompt_token_ids=output.cumulative_token_ids, - additional_information=multi_modal_data, - ) - return [engine_input] From f8412c7e7a6171c686c499ea43d9140536fa0532 Mon Sep 17 00:00:00 2001 From: "leo.yang" Date: Mon, 1 Jun 2026 18:06:08 +0200 Subject: [PATCH 2/2] fix: remove empty kv_transfer package after monkey_patch.py deletion Signed-off-by: leo.yang --- vllm_omni/distributed/kv_transfer/__init__.py | 13 ------------- 1 file changed, 13 deletions(-) delete mode 100644 vllm_omni/distributed/kv_transfer/__init__.py diff --git a/vllm_omni/distributed/kv_transfer/__init__.py b/vllm_omni/distributed/kv_transfer/__init__.py deleted file mode 100644 index f8914a428ba..00000000000 --- a/vllm_omni/distributed/kv_transfer/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Patched KV transfer connectors for PD disaggregation. - -This package provides monkey-patched versions of vLLM's native KV transfer -connectors (e.g. MooncakeConnector) that fix the request-ID mismatch problem -in prefill-decode disaggregation. - -vLLM's ``InputProcessor.assign_request_id()`` appends a random 8-char suffix -to each request ID internally. The prefill engine stores KV under its own -suffix, but the decode engine generates a *different* suffix — so it can never -find the KV data. The patched connector threads the prefill engine's internal -``remote_request_id`` through ``kv_transfer_params`` so the decode side can -reference the correct KV entry. -"""