diff --git a/pyproject.toml b/pyproject.toml index 79d15672de2..77853a532e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,9 @@ dependencies = [ "soundfile>=0.13.1", "cache-dit==1.1.8", "tqdm>=4.66.0", + "mamba-ssm>=1.2.0", + "causal-conv1d>=1.2.0", + "decord>=0.6.0", # "vllm==0.12.0", # TODO: fix the entrypoints overwrite problem ] diff --git a/vllm_omni/model_executor/models/__init__.py b/vllm_omni/model_executor/models/__init__.py index 0b2629b4a55..5611ade607a 100644 --- a/vllm_omni/model_executor/models/__init__.py +++ b/vllm_omni/model_executor/models/__init__.py @@ -1,4 +1,8 @@ from .qwen3_omni import Qwen3OmniMoeForConditionalGeneration +from .hyperclovax_seed_omni import HyperCLOVAXSeedOmniForConditionalGeneration from .registry import OmniModelRegistry # noqa: F401 -__all__ = ["Qwen3OmniMoeForConditionalGeneration"] +__all__ = [ + "Qwen3OmniMoeForConditionalGeneration", + "HyperCLOVAXSeedOmniForConditionalGeneration", +] diff --git a/vllm_omni/model_executor/models/hyperclovax_seed_omni/__init__.py b/vllm_omni/model_executor/models/hyperclovax_seed_omni/__init__.py new file mode 100644 index 00000000000..5867728f5f0 --- /dev/null +++ b/vllm_omni/model_executor/models/hyperclovax_seed_omni/__init__.py @@ -0,0 +1,3 @@ +from .hyperclovax_seed_omni import HyperCLOVAXSeedOmniForConditionalGeneration + +__all__ = ["HyperCLOVAXSeedOmniForConditionalGeneration"] diff --git a/vllm_omni/model_executor/models/hyperclovax_seed_omni/hyperclovax_seed_omni.py b/vllm_omni/model_executor/models/hyperclovax_seed_omni/hyperclovax_seed_omni.py new file mode 100644 index 00000000000..cb14ec2210e --- /dev/null +++ b/vllm_omni/model_executor/models/hyperclovax_seed_omni/hyperclovax_seed_omni.py @@ -0,0 +1,191 @@ +import torch +import torch.nn as nn +from typing import Optional, Union, Dict, Any, Iterable, Set, Tuple + +from transformers import AutoModel, AutoConfig +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP +from vllm.model_executor.models.utils import init_vllm_registered_model, maybe_prefix, add_prefix_to_loaded_weights +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import IntermediateTensors +from vllm.v1.sample.sampler import Sampler + +from vllm_omni.model_executor.models.output_templates import OmniOutput +from vllm_omni.model_executor.models.hyperclovax_seed_omni.hyperclovax_seed_omni_thinker import ( + HyperCLOVAXSeedOmniThinkerForConditionalGeneration, + HyperCLOVAXSeedOmniThinkerMultiModalProcessor +) + +logger = init_logger(__name__) + +class HyperCLOVAXSeedOmniAudioDecoder(nn.Module): + """ + Wrapper for CosyVoice2 Audio Decoder. + Loads the actual model implementation from the Hugging Face repository using trust_remote_code. + """ + def __init__(self, config, model_path: Optional[str] = None): + super().__init__() + self.config = config + + # Load the actual CosyVoice2 model from the HF repository + # This executes the remote code (cosyvoice.py) provided by the model authors. + try: + path_to_load = model_path if model_path else str(getattr(config, "_name_or_path", "")) + logger.info(f"Loading CosyVoice2 Audio Decoder from {path_to_load}...") + + # Use AutoModel with trust_remote_code=True to load the custom model class + self.model = AutoModel.from_pretrained( + path_to_load, + config=config, + trust_remote_code=True + ) + logger.info("Successfully loaded CosyVoice2 Audio Decoder.") + + except Exception as e: + logger.error(f"Failed to load CosyVoice2 model via AutoModel: {e}") + logger.warning("Falling back to placeholder for compilation checks (Runtime will fail if not fixed).") + self.model = None + + def forward(self, audio_tokens: torch.Tensor) -> torch.Tensor: + """ + Args: + audio_tokens: (Batch, Seq_Len) tensor of discrete audio tokens + Returns: + waveform: (Batch, Audio_Len) tensor of generated audio waveform + """ + if self.model is None: + # Runtime fail-safe if model loading failed + raise RuntimeError("CosyVoice2 model failed to load. Cannot generate audio.") + + # Call the actual model's inference/decode method. + # The method name depends on the specific implementation in cosyvoice.py. + # Common patterns: model.decode(), model.inference(), or forward() + + # Assuming standard forward or a specific decode method exposed by the remote code + # We try to detect the method or default to forward + if hasattr(self.model, "decode"): + return self.model.decode(audio_tokens) + elif hasattr(self.model, "inference"): + return self.model.inference(audio_tokens) + else: + # Default forward pass + return self.model(audio_tokens) + + def load_weights(self, weights): + # AutoModel handles weight loading during init, so we might not need manual loading here + # unless vLLM requires it. For external models loaded via AutoModel, we return empty set + # to tell vLLM "we handled it". + return set() + +@MULTIMODAL_REGISTRY.register_processor( + HyperCLOVAXSeedOmniThinkerMultiModalProcessor, + info=None, + dummy_inputs=None, +) +class HyperCLOVAXSeedOmniForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP +): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.vllm_config = vllm_config + self.config = vllm_config.model_config.hf_config + self.prefix = prefix + + self.model_stage = vllm_config.model_config.model_stage + self.model_path = vllm_config.model_config.model # Original model path + + self.thinker: Optional[HyperCLOVAXSeedOmniThinkerForConditionalGeneration] = None + self.code2wav: Optional[HyperCLOVAXSeedOmniAudioDecoder] = None + self.model: Optional[nn.Module] = None + + if self.model_stage == "thinker": + # Initialize thinker model (LLM + Encoders) + self.thinker = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "thinker"), + hf_config=self.config, + architectures=["HyperCLOVAXSeedOmniThinkerForConditionalGeneration"], + ) + self.model = self.thinker + + elif self.model_stage == "code2wav": + # Initialize audio decoder (CosyVoice2 wrapper) + # We pass the full config and the model path to allow loading the remote code + + # Check if there's a specific audio decoder config section + audio_config = getattr(self.config, "audio_decoder_config", self.config) + + self.code2wav = HyperCLOVAXSeedOmniAudioDecoder(audio_config, model_path=self.model_path) + self.model = self.code2wav + + else: + raise ValueError(f"Invalid model stage: {self.model_stage}. Supported: 'thinker', 'code2wav'") + + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors, OmniOutput]: + + if self.model_stage == "thinker" and self.thinker is not None: + hidden_states = self.thinker( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs + ) + return OmniOutput( + text_hidden_states=hidden_states, + multimodal_outputs=None + ) + + elif self.model_stage == "code2wav" and self.code2wav is not None: + # Input to Code2Wav stage are the Audio Tokens generated by Thinker + audio_tokens = input_ids + waveform = self.code2wav(audio_tokens) + return OmniOutput( + text_hidden_states=None, + multimodal_outputs={"model_outputs": waveform} + ) + + raise RuntimeError("Model stage not initialized correctly") + + def compute_logits(self, hidden_states: Union[torch.Tensor, OmniOutput], sampling_metadata: Optional[Any] = None) -> Optional[torch.Tensor]: + if isinstance(hidden_states, OmniOutput): + hidden_states = hidden_states.text_hidden_states + if self.model_stage == "thinker" and self.thinker is not None: + return self.thinker.compute_logits(hidden_states, sampling_metadata) + return None + + def sample(self, logits: torch.Tensor, sampling_metadata: Any) -> Any: + if self.model_stage == "thinker" and self.thinker is not None: + return self.thinker.language_model.sample(logits, sampling_metadata) + return None + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: + loaded_weights = set() + thinker_weights = [] + + # If we are in code2wav stage, AutoModel loaded the weights already. + if self.model_stage == "code2wav": + # Consume generator to prevent warnings about unconsumed weights if applicable + for _ in weights: pass + return set() + + # Filter weights for Thinker stage + for k, v in weights: + if k.startswith("thinker.") or k.startswith("language_model.") or k.startswith("vision_encoder.") or k.startswith("audio_encoder."): + thinker_weights.append((k, v)) + + if self.thinker: + loaded = self.thinker.load_weights(thinker_weights) + loaded_weights.update(add_prefix_to_loaded_weights(loaded, "thinker")) + + return loaded_weights diff --git a/vllm_omni/model_executor/models/hyperclovax_seed_omni/hyperclovax_seed_omni_thinker.py b/vllm_omni/model_executor/models/hyperclovax_seed_omni/hyperclovax_seed_omni_thinker.py new file mode 100644 index 00000000000..9fe7d6fda45 --- /dev/null +++ b/vllm_omni/model_executor/models/hyperclovax_seed_omni/hyperclovax_seed_omni_thinker.py @@ -0,0 +1,190 @@ +from typing import Any, List, Optional, Tuple, Union, Iterable, Set + +import torch +import torch.nn as nn +from transformers import AutoModel, AutoConfig + +from vllm.config import VllmConfig +from vllm.model_executor.models.interfaces import ( + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + init_vllm_registered_model, + maybe_prefix, + extract_layer_inputs, + merge_multimodal_embeddings, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.logger import init_logger +from .mamba_mia import MambaMiaCompressor + +logger = init_logger(__name__) + +class HyperCLOVAXSeedOmniThinkerMultiModalProcessor: + def __init__(self, *args, **kwargs): + pass + + def apply(self, *args, **kwargs): + return args[0] + +@MULTIMODAL_REGISTRY.register_processor( + HyperCLOVAXSeedOmniThinkerMultiModalProcessor, + info=None, + dummy_inputs=None, +) +class HyperCLOVAXSeedOmniThinkerForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsPP, + SupportsLoRA, +): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.vllm_config = vllm_config + self.config = vllm_config.model_config.hf_config + + # 1. Initialize LLM backbone (Llama-based 8B) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + hf_config=self.config, + architectures=["LlamaForCausalLM"], + ) + + # 2. Initialize Encoders using HF AutoModel + # Vision Encoder + if hasattr(self.config, "vision_config"): + try: + self.vision_encoder = AutoModel.from_config(self.config.vision_config, trust_remote_code=True) + logger.info("Successfully loaded Vision Encoder via AutoModel (remote code).") + except Exception as e: + logger.error(f"Failed to load Vision Encoder: {e}") + raise RuntimeError("Could not load Vision Encoder.") from e + else: + self.vision_encoder = None + + # Audio Encoder + if hasattr(self.config, "audio_config"): + try: + self.audio_encoder = AutoModel.from_config(self.config.audio_config, trust_remote_code=True) + logger.info("Successfully loaded Audio Encoder via AutoModel (remote code).") + except Exception as e: + logger.error(f"Failed to load Audio Encoder: {e}") + raise RuntimeError("Could not load Audio Encoder.") from e + else: + self.audio_encoder = None + + # 3. MambaMia Compression (Video) + if getattr(self.config, "use_mamba_mia", False): + logger.info("Initializing MambaMia Video Compressor") + self.mamba_mia = MambaMiaCompressor(self.config) + else: + self.mamba_mia = None + + self.make_empty_intermediate_tensors = self.language_model.make_empty_intermediate_tensors + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[Any] = None, + **kwargs, + ) -> torch.Tensor: + return self.language_model.model.embed_tokens(input_ids) + + def _process_vision_input(self, pixel_values: torch.Tensor) -> torch.Tensor: + # Forward pass through Vision Encoder + if self.vision_encoder is None: + raise ValueError("Vision inputs provided but Vision Encoder is not initialized.") + + vision_outputs = self.vision_encoder(pixel_values) + + # Handle different output formats (hidden_states or direct tensor) + if hasattr(vision_outputs, "last_hidden_state"): + image_features = vision_outputs.last_hidden_state + else: + image_features = vision_outputs + + # Apply MambaMia compression if enabled (typically for video) + # Assuming pixel_values shape hints at video (Batch, Num_Frames, C, H, W) vs Image + if self.mamba_mia is not None: + # Check if input is video-like or if we should always apply it + # For now, apply if MambaMia is initialized + image_features = self.mamba_mia(image_features) + + return image_features + + def _process_audio_input(self, audio_values: torch.Tensor) -> torch.Tensor: + if self.audio_encoder is None: + raise ValueError("Audio inputs provided but Audio Encoder is not initialized.") + + audio_outputs = self.audio_encoder(audio_values) + if hasattr(audio_outputs, "last_hidden_state"): + return audio_outputs.last_hidden_state + return audio_outputs + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, torch.Tensor]: + + # 1. Process Multimodal Inputs if present in kwargs + pixel_values = kwargs.pop("pixel_values", None) + audio_values = kwargs.pop("audio_values", None) + + vision_embeddings = None + audio_embeddings = None + + if pixel_values is not None: + vision_embeddings = self._process_vision_input(pixel_values) + + if audio_values is not None: + audio_embeddings = self._process_audio_input(audio_values) + + # 2. Embed Text Inputs if inputs_embeds not provided + if inputs_embeds is None: + inputs_embeds = self.embed_input_ids(input_ids) + + # 3. Merge Embeddings + # If we have multimodal embeddings, we need to merge them into inputs_embeds + # vLLM provides `merge_multimodal_embeddings` utility, but specific logic depends on model + # Here we assume a simple replacement or concatenation strategy supported by vLLM + # In a real scenario, this would use `merge_multimodal_embeddings` with the model's placeholder tokens. + + if vision_embeddings is not None or audio_embeddings is not None: + # Create a dictionary of embeddings to merge + mm_embeddings_dict = {} + if vision_embeddings is not None: + mm_embeddings_dict["image"] = vision_embeddings + if audio_embeddings is not None: + mm_embeddings_dict["audio"] = audio_embeddings + + # Use vLLM's utility to merge (this requires input_ids to have placeholder tokens) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, mm_embeddings_dict, self.config + ) + + # 4. Forward through LLM + return self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: Optional[Any] = None, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/model_executor/models/hyperclovax_seed_omni/mamba_mia.py b/vllm_omni/model_executor/models/hyperclovax_seed_omni/mamba_mia.py new file mode 100644 index 00000000000..775d3f59e0d --- /dev/null +++ b/vllm_omni/model_executor/models/hyperclovax_seed_omni/mamba_mia.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn +from typing import Optional, Tuple + +# Placeholder for Mamba/SSM implementation +# In a real environment, this would import from `mamba_ssm` or vLLM's internal mamba kernels +try: + from mamba_ssm import Mamba +except ImportError: + # Fallback dummy Mamba class if package is missing + class Mamba(nn.Module): + def __init__(self, d_model, d_state=16, d_conv=4, expand=2): + super().__init__() + self.d_model = d_model + self.proj = nn.Linear(d_model, d_model) + def forward(self, x): + return self.proj(x) + +class MambaMiaCompressor(nn.Module): + """ + MambaMia: State-Space-Model-Based Compression for Video Understanding. + Compresses video frame tokens using Bidirectional Mamba blocks and weighted pooling. + """ + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.compression_rate = getattr(config, "mamba_mia_compression_rate", 4) # Example default + + # Learnable Queries for pooling + # "periodically inserted learned queries" + self.num_queries = getattr(config, "num_queries", 64) + self.query_embed = nn.Parameter(torch.randn(1, self.num_queries, self.hidden_size)) + + # Bidirectional Mamba Block + # "bidirectional state-space-based block equipped with a gated skip connection" + self.mamba_fwd = Mamba( + d_model=self.hidden_size, + d_state=getattr(config, "mamba_d_state", 16), + d_conv=getattr(config, "mamba_d_conv", 4), + expand=getattr(config, "mamba_expand", 2), + ) + self.mamba_bwd = Mamba( + d_model=self.hidden_size, + d_state=getattr(config, "mamba_d_state", 16), + d_conv=getattr(config, "mamba_d_conv", 4), + expand=getattr(config, "mamba_expand", 2), + ) + + self.norm = nn.LayerNorm(self.hidden_size) + self.proj_out = nn.Linear(self.hidden_size * 2, self.hidden_size) # Fuse fwd/bwd + + # Gating mechanism for skip connection + self.gate = nn.Linear(self.hidden_size * 2, self.hidden_size) + self.sigmoid = nn.Sigmoid() + + def forward(self, video_features: torch.Tensor) -> torch.Tensor: + """ + Args: + video_features: (Batch, Num_Frames * Tokens_Per_Frame, Hidden_Size) + or (Batch, Seq_Len, Hidden_Size) + """ + B, L, D = video_features.shape + + # 1. Bidirectional Mamba Processing + # Forward pass + x_fwd = self.mamba_fwd(video_features) + + # Backward pass (flip sequence) + x_bwd = self.mamba_bwd(torch.flip(video_features, dims=[1])) + x_bwd = torch.flip(x_bwd, dims=[1]) + + # Fuse directions + x_processed = torch.cat([x_fwd, x_bwd], dim=-1) # (B, L, 2*D) + + # Gated Skip Connection logic (Simplified interpretation of MambaMia) + # Assuming we project back to D and add residual + gate_score = self.sigmoid(self.gate(x_processed)) + x_fused = self.proj_out(x_processed) * gate_score + + x_out = self.norm(x_fused + video_features) + + # 2. Learnable Weighted-Average Pooling / Query-based Downsampling + # MambaMia uses queries to pool information from the processed sequence + # We perform cross-attention or simple weighted pooling based on queries + + # (Batch, Num_Queries, D) + queries = self.query_embed.expand(B, -1, -1) + + # Simple attention-based pooling for compression + # Q = Queries, K=V = x_out + # (B, Q_len, D) x (B, D, L) -> (B, Q_len, L) + attn_logits = torch.bmm(queries, x_out.transpose(1, 2)) / (D ** 0.5) + attn_weights = torch.softmax(attn_logits, dim=-1) + + compressed_features = torch.bmm(attn_weights, x_out) # (B, Num_Queries, D) + + return compressed_features diff --git a/vllm_omni/model_executor/models/registry.py b/vllm_omni/model_executor/models/registry.py index 56bceae41ab..858604327d0 100644 --- a/vllm_omni/model_executor/models/registry.py +++ b/vllm_omni/model_executor/models/registry.py @@ -48,6 +48,16 @@ "qwen3_omni_code2wav", "Qwen3OmniMoeCode2Wav", ), + "HyperCLOVAXSeedOmniForConditionalGeneration": ( + "hyperclovax_seed_omni", + "hyperclovax_seed_omni", + "HyperCLOVAXSeedOmniForConditionalGeneration", + ), + "HyperCLOVAXSeedOmniThinkerForConditionalGeneration": ( + "hyperclovax_seed_omni", + "hyperclovax_seed_omni_thinker", + "HyperCLOVAXSeedOmniThinkerForConditionalGeneration", + ), } _VLLM_OMNI_MODELS = { diff --git a/vllm_omni/model_executor/stage_configs/hyperclovax_seed_omni.yaml b/vllm_omni/model_executor/stage_configs/hyperclovax_seed_omni.yaml new file mode 100644 index 00000000000..101cb58c82a --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/hyperclovax_seed_omni.yaml @@ -0,0 +1,61 @@ +# Stage config for HyperCLOVAX-SEED-Omni-8B + +stage_args: + - stage_id: 0 + stage_type: llm + runtime: + process: true + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: HyperCLOVAXSeedOmniForConditionalGeneration + worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.8 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent + enable_prefix_caching: false + max_num_batched_tokens: 32768 + is_comprehension: true + final_output: true + final_output_type: text + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + max_tokens: 2048 + detokenize: True + + - stage_id: 1 + stage_type: llm + runtime: + process: true + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: HyperCLOVAXSeedOmniForConditionalGeneration + worker_cls: vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + gpu_memory_utilization: 0.15 + enforce_eager: true + trust_remote_code: true + engine_output_type: audio + engine_input_source: [0] + # Input processor to convert Thinker output (tokens) to Code2Wav input + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.hyperclovax_seed_omni.thinker2code2wav + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + edges: + - from: 0 + to: 1 + window_size: -1 diff --git a/vllm_omni/model_executor/stage_input_processors/hyperclovax_seed_omni.py b/vllm_omni/model_executor/stage_input_processors/hyperclovax_seed_omni.py new file mode 100644 index 00000000000..f30350e804a --- /dev/null +++ b/vllm_omni/model_executor/stage_input_processors/hyperclovax_seed_omni.py @@ -0,0 +1,41 @@ +from typing import List, Union +import torch +from vllm_omni.inputs.data import OmniTokensPrompt + +def thinker2code2wav( + stage_list, + engine_input_source, + prompt=None, + requires_multimodal_data: bool = False, +): + """ + Process output from Thinker (LLM) stage and prepare input for Code2Wav stage. + Assumes Thinker output contains audio tokens that need to be decoded. + """ + if not engine_input_source: + raise ValueError("engine_input_source cannot be empty") + source_stage_id = engine_input_source[0] + + thinker_outputs = stage_list[source_stage_id].engine_outputs + code2wav_inputs = [] + + # Iterate over batch + for i, thinker_output in enumerate(thinker_outputs): + output = thinker_output.outputs[0] + # Get generated tokens + token_ids = output.token_ids + + # Filter for audio tokens if they are mixed with text. + # For now, we assume ALL generated tokens are passed to code2wav + # or that there's a specific range/mask. + # In a real implementation, we'd filter based on token IDs (e.g. > specific ID). + audio_tokens = token_ids + + code2wav_inputs.append( + OmniTokensPrompt( + prompt_token_ids=audio_tokens, + multi_modal_data=None # Audio generation usually doesn't need MM input again + ) + ) + + return code2wav_inputs