Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ dependencies = [
"soundfile>=0.13.1",
"cache-dit==1.1.8",
"tqdm>=4.66.0",
"mamba-ssm>=1.2.0",
"causal-conv1d>=1.2.0",
"decord>=0.6.0",
# "vllm==0.12.0", # TODO: fix the entrypoints overwrite problem
]

Expand Down
6 changes: 5 additions & 1 deletion vllm_omni/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from .qwen3_omni import Qwen3OmniMoeForConditionalGeneration
from .hyperclovax_seed_omni import HyperCLOVAXSeedOmniForConditionalGeneration
from .registry import OmniModelRegistry # noqa: F401

__all__ = ["Qwen3OmniMoeForConditionalGeneration"]
__all__ = [
"Qwen3OmniMoeForConditionalGeneration",
"HyperCLOVAXSeedOmniForConditionalGeneration",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .hyperclovax_seed_omni import HyperCLOVAXSeedOmniForConditionalGeneration

__all__ = ["HyperCLOVAXSeedOmniForConditionalGeneration"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import torch
import torch.nn as nn
from typing import Optional, Union, Dict, Any, Iterable, Set, Tuple

from transformers import AutoModel, AutoConfig
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
from vllm.model_executor.models.utils import init_vllm_registered_model, maybe_prefix, add_prefix_to_loaded_weights
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from vllm.v1.sample.sampler import Sampler

from vllm_omni.model_executor.models.output_templates import OmniOutput
from vllm_omni.model_executor.models.hyperclovax_seed_omni.hyperclovax_seed_omni_thinker import (
HyperCLOVAXSeedOmniThinkerForConditionalGeneration,
HyperCLOVAXSeedOmniThinkerMultiModalProcessor
)

logger = init_logger(__name__)

class HyperCLOVAXSeedOmniAudioDecoder(nn.Module):
    """Wrapper for the CosyVoice2 audio decoder.

    The decoder implementation is loaded from the Hugging Face repository via
    ``AutoModel.from_pretrained(..., trust_remote_code=True)``. Loading is
    best-effort: a failure is logged and leaves ``self.model`` as ``None`` so
    the wrapper can still be constructed; ``forward`` then raises
    ``RuntimeError`` chained to the original load failure.
    """

    def __init__(self, config, model_path: Optional[str] = None):
        """Build the wrapper and try to load the remote decoder.

        Args:
            config: Configuration object handed to ``AutoModel.from_pretrained``.
            model_path: Repository id or local path to load from. When empty,
                ``config._name_or_path`` is used (empty string if absent).
        """
        super().__init__()
        self.config = config
        # Cause of a failed load, kept so forward() can report it instead of
        # only saying "failed to load".
        self._load_error: Optional[Exception] = None

        # SECURITY NOTE: trust_remote_code=True executes Python code shipped
        # with the model repository; only point this at trusted sources.
        try:
            path_to_load = model_path if model_path else str(getattr(config, "_name_or_path", ""))
            logger.info("Loading CosyVoice2 Audio Decoder from %s...", path_to_load)

            # Use AutoModel with trust_remote_code=True to load the custom model class.
            self.model = AutoModel.from_pretrained(
                path_to_load,
                config=config,
                trust_remote_code=True,
            )
            logger.info("Successfully loaded CosyVoice2 Audio Decoder.")
        except Exception as e:
            # Deliberate best-effort: keep construction alive, fail at first use.
            # logger.exception also records the traceback of the root cause.
            logger.exception("Failed to load CosyVoice2 model via AutoModel: %s", e)
            logger.warning("Falling back to placeholder for compilation checks (Runtime will fail if not fixed).")
            self._load_error = e
            self.model = None

    def forward(self, audio_tokens: torch.Tensor) -> torch.Tensor:
        """Decode discrete audio tokens with the wrapped model.

        Args:
            audio_tokens: Tensor of discrete audio tokens
                (documented by the original author as ``(Batch, Seq_Len)``).

        Returns:
            Whatever the wrapped model's decode entry point returns
            (documented by the original author as a ``(Batch, Audio_Len)``
            waveform tensor).

        Raises:
            RuntimeError: If the wrapped model failed to load at construction.
        """
        if self.model is None:
            # getattr keeps this safe for instances created without __init__.
            raise RuntimeError(
                "CosyVoice2 model failed to load. Cannot generate audio."
            ) from getattr(self, "_load_error", None)

        # The remote code's entry point name is not known here, so probe the
        # common ones in priority order and fall back to a plain call.
        if hasattr(self.model, "decode"):
            return self.model.decode(audio_tokens)
        if hasattr(self.model, "inference"):
            return self.model.inference(audio_tokens)
        return self.model(audio_tokens)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        """Report no weights loaded here; AutoModel loaded them during ``__init__``.

        Args:
            weights: Checkpoint ``(name, tensor)`` pairs; ignored.

        Returns:
            An empty set.
        """
        return set()

@MULTIMODAL_REGISTRY.register_processor(
    HyperCLOVAXSeedOmniThinkerMultiModalProcessor,
    info=None,
    dummy_inputs=None,
)
class HyperCLOVAXSeedOmniForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP
):
    """Stage-dispatching entry point for the HyperCLOVAX Seed Omni model.

    Exactly one sub-model is built, selected by
    ``vllm_config.model_config.model_stage``:

    * ``"thinker"``  - the registered thinker model; ``forward`` returns its
      hidden states wrapped in an ``OmniOutput``.
    * ``"code2wav"`` - the CosyVoice2 audio decoder wrapper; ``forward`` treats
      ``input_ids`` as audio tokens and returns the waveform inside
      ``OmniOutput.multimodal_outputs``.
    """

    # Checkpoint key prefixes that are routed to the thinker sub-model.
    _THINKER_WEIGHT_PREFIXES = (
        "thinker.",
        "language_model.",
        "vision_encoder.",
        "audio_encoder.",
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the sub-model for the configured stage.

        Args:
            vllm_config: vLLM configuration; its ``model_config`` supplies the
                HF config, the model stage and the model path.
            prefix: Module-name prefix used when registering the thinker.

        Raises:
            ValueError: If the stage is neither ``"thinker"`` nor ``"code2wav"``.
        """
        super().__init__()
        self.vllm_config = vllm_config
        self.config = vllm_config.model_config.hf_config
        self.prefix = prefix

        self.model_stage = vllm_config.model_config.model_stage
        self.model_path = vllm_config.model_config.model  # Original model path

        # forward() returns OmniOutput. Per the review of this change, the
        # generation runner's extract_multimodal_outputs only accepts an
        # OmniOutput from models that advertise this flag and otherwise raises
        # ValueError("Invalid hidden states type"), so it must be set.
        self.have_multimodal_outputs = True

        self.thinker: Optional[HyperCLOVAXSeedOmniThinkerForConditionalGeneration] = None
        self.code2wav: Optional[HyperCLOVAXSeedOmniAudioDecoder] = None
        self.model: Optional[nn.Module] = None

        if self.model_stage == "thinker":
            # Initialize thinker model (LLM + Encoders).
            self.thinker = init_vllm_registered_model(
                vllm_config=vllm_config,
                prefix=maybe_prefix(prefix, "thinker"),
                hf_config=self.config,
                architectures=["HyperCLOVAXSeedOmniThinkerForConditionalGeneration"],
            )
            self.model = self.thinker

        elif self.model_stage == "code2wav":
            # Initialize audio decoder (CosyVoice2 wrapper). Use the dedicated
            # audio decoder config section when present, else the full config;
            # the model path lets the wrapper load the remote code.
            audio_config = getattr(self.config, "audio_decoder_config", self.config)

            self.code2wav = HyperCLOVAXSeedOmniAudioDecoder(audio_config, model_path=self.model_path)
            self.model = self.code2wav

        else:
            raise ValueError(f"Invalid model stage: {self.model_stage}. Supported: 'thinker', 'code2wav'")

        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors, OmniOutput]:
        """Run the active stage and wrap its result in an ``OmniOutput``.

        Args:
            input_ids: Token ids; in the code2wav stage these are used as the
                audio tokens fed to the decoder.
            positions: Position ids, forwarded to the thinker.
            intermediate_tensors: Forwarded to the thinker.
            inputs_embeds: Forwarded to the thinker.
            **kwargs: Extra keyword arguments forwarded to the thinker.

        Returns:
            ``OmniOutput`` carrying thinker hidden states (thinker stage) or
            ``{"model_outputs": waveform}`` (code2wav stage).

        Raises:
            RuntimeError: If the sub-model for the configured stage is missing.
        """
        if self.model_stage == "thinker" and self.thinker is not None:
            hidden_states = self.thinker(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                **kwargs,
            )
            return OmniOutput(
                text_hidden_states=hidden_states,
                multimodal_outputs=None,
            )

        if self.model_stage == "code2wav" and self.code2wav is not None:
            # Input to the code2wav stage is the audio tokens generated by the thinker.
            audio_tokens = input_ids
            waveform = self.code2wav(audio_tokens)
            return OmniOutput(
                text_hidden_states=None,
                multimodal_outputs={"model_outputs": waveform},
            )

        raise RuntimeError("Model stage not initialized correctly")

    def compute_logits(
        self,
        hidden_states: Union[torch.Tensor, OmniOutput],
        sampling_metadata: Optional[Any] = None,
    ) -> Optional[torch.Tensor]:
        """Compute logits through the thinker; return ``None`` for other stages.

        Args:
            hidden_states: Hidden states, or an ``OmniOutput`` whose
                ``text_hidden_states`` are used.
            sampling_metadata: Forwarded to the thinker's ``compute_logits``.
        """
        if isinstance(hidden_states, OmniOutput):
            hidden_states = hidden_states.text_hidden_states
        if self.model_stage == "thinker" and self.thinker is not None:
            return self.thinker.compute_logits(hidden_states, sampling_metadata)
        return None

    def sample(self, logits: torch.Tensor, sampling_metadata: Any) -> Any:
        """Sample via the thinker's language model; return ``None`` for other stages."""
        if self.model_stage == "thinker" and self.thinker is not None:
            return self.thinker.language_model.sample(logits, sampling_metadata)
        return None

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        """Route checkpoint weights to the active stage.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            Names reported as loaded: the thinker's loaded names with a
            ``"thinker"`` prefix added, or an empty set in the code2wav stage.
        """
        # In the code2wav stage AutoModel already loaded the weights.
        if self.model_stage == "code2wav":
            # Consume the iterable to prevent warnings about unconsumed weights.
            for _ in weights:
                pass
            return set()

        # Keep only the entries that belong to the thinker.
        # NOTE(review): names are forwarded unchanged (a leading "thinker." is
        # not stripped) - confirm this matches what thinker.load_weights expects.
        thinker_weights = [
            (name, tensor)
            for name, tensor in weights
            if name.startswith(self._THINKER_WEIGHT_PREFIXES)
        ]

        loaded_weights: Set[str] = set()
        if self.thinker is not None:
            loaded = self.thinker.load_weights(thinker_weights)
            loaded_weights.update(add_prefix_to_loaded_weights(loaded, "thinker"))

        return loaded_weights
Loading
Loading