diff --git a/examples/conversion/hf_to_megatron_generate_omni_lm.py b/examples/conversion/hf_to_megatron_generate_omni_lm.py new file mode 100644 index 0000000000..510f598863 --- /dev/null +++ b/examples/conversion/hf_to_megatron_generate_omni_lm.py @@ -0,0 +1,482 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Omni-Language Model Generation Script for Qwen3-Omni. + +This script demonstrates how to use Qwen3-Omni models with Megatron-Bridge +for video understanding tasks (with optional audio from video). + +Requirements: + pip install qwen-omni-utils[decord] + +Example: + + uv run --no-sync python -m torch.distributed.run --nproc_per_node=2 examples/conversion/hf_to_megatron_generate_omni_lm.py \ + --hf_model_path=Qwen/Qwen2.5-Omni-7B \ + --video_url="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/cookbook/audio_visual.mp4" \ + --prompt="What was the first sentence the boy said when he met the girl?" 
\ + --use_audio_in_video \ + --tp 2 \ + --trust_remote_code +""" + +import argparse +from typing import Optional + +import torch +import torch.distributed as dist +from megatron.core import parallel_state +from megatron.core.pipeline_parallel.schedules import get_forward_backward_func +from transformers import AutoProcessor, AutoTokenizer + +from megatron.bridge import AutoBridge +from megatron.bridge.models.hf_pretrained.utils import is_safe_repo +from megatron.bridge.utils.common_utils import get_last_rank, print_rank_0 + + +# Try to import qwen_omni_utils for video/audio processing +try: + from qwen_omni_utils import process_mm_info + + HAS_QWEN_OMNI_UTILS = True +except ImportError: + process_mm_info = None + HAS_QWEN_OMNI_UTILS = False + + +class SingleBatchIterator: + """Iterator that yields a single batch of data for omni-language generation. + Required by the forward_backward_func function. + + This class creates an iterator that yields exactly one batch containing + input tokens, attention mask, and optional video/audio inputs, + then raises StopIteration. Used for single-step inference in the forward pass. 
+ """ + + def __init__( + self, + input_ids, + attention_mask, + pixel_values_videos=None, + video_grid_thw=None, + video_second_per_grid=None, + input_features=None, + feature_attention_mask=None, + use_audio_in_video=None, + ): + self.batch = dict( + tokens=input_ids, + attention_mask=attention_mask, + ) + + # Add video inputs if provided + if pixel_values_videos is not None: + self.batch["pixel_values_videos"] = pixel_values_videos + if video_grid_thw is not None: + self.batch["video_grid_thw"] = video_grid_thw + if video_second_per_grid is not None: + self.batch["video_second_per_grid"] = video_second_per_grid + + # Add audio inputs if provided + if input_features is not None: + self.batch["input_features"] = input_features + if feature_attention_mask is not None: + self.batch["feature_attention_mask"] = feature_attention_mask + + if use_audio_in_video is not None: + self.batch["use_audio_in_video"] = use_audio_in_video + + self._yielded = False + + def __iter__(self): + return self + + def __next__(self): + if self._yielded: + raise StopIteration + self._yielded = True + return self.batch + + +def omni_forward_step(data_iterator, model, **kwargs) -> torch.Tensor: + """Forward step function for omni-language generation. + Required by the forward_backward_func function. + + Extracts a batch from the data iterator and runs the model forward pass + with the provided input tokens, attention mask, video inputs, and audio inputs. + Position IDs are computed internally by the model using multimodal RoPE. 
+ + Args: + data_iterator: Iterator providing batches of input data + model: The Megatron model to run forward pass on + **kwargs: Additional keyword arguments (unused) + + Returns: + Tuple of (model_output, loss_function) + """ + batch = next(data_iterator) + forward_args = { + "input_ids": batch["tokens"], + "position_ids": None, # Let model compute mrope position_ids internally + "attention_mask": batch.get("attention_mask", None), + } + + # Add video inputs if present + if "pixel_values_videos" in batch: + forward_args["pixel_values_videos"] = batch["pixel_values_videos"] + if "video_grid_thw" in batch: + forward_args["video_grid_thw"] = batch["video_grid_thw"] + if "video_second_per_grid" in batch: + forward_args["video_second_per_grid"] = batch["video_second_per_grid"] + + # Add audio inputs if present + if "input_features" in batch: + forward_args["input_features"] = batch["input_features"] + if "feature_attention_mask" in batch: + forward_args["feature_attention_mask"] = batch["feature_attention_mask"] + + if "use_audio_in_video" in batch: + forward_args["use_audio_in_video"] = batch["use_audio_in_video"] + + def loss_func(x, **kwargs): + return x + + model_output = model(**forward_args) + if isinstance(model_output, tuple): + output_tensor, _ = model_output + else: + output_tensor = model_output + + return output_tensor, loss_func + + +def process_omni_inputs(processor, video_path: Optional[str], prompt: str, use_audio_in_video: bool): + """Process video/audio inputs for omni-language model. 
+ + Args: + processor: AutoProcessor for the omni-language model + video_path: Path or URL to the video file (optional) + prompt: Text prompt + use_audio_in_video: Whether to use audio track from the video + + Returns: + Dict containing processed inputs and messages + """ + if video_path: + # Create messages with video and text for Qwen3-Omni format + messages = [ + { + "role": "user", + "content": [ + {"type": "video", "video": video_path}, + {"type": "text", "text": prompt}, + ], + } + ] + + # Apply chat template + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + if not HAS_QWEN_OMNI_UTILS: + raise ImportError( + "qwen_omni_utils is required for video processing. " + "Please install it: pip install qwen-omni-utils[decord]" + ) + + # Extract audios, images, videos from messages + audios, images, videos = process_mm_info(messages, use_audio_in_video=use_audio_in_video) + + # Process inputs with video (and optionally audio) + inputs = processor( + text=text, + audio=audios, + images=images, + videos=videos, + return_tensors="pt", + padding=True, + use_audio_in_video=use_audio_in_video, + ) + + return { + "input_ids": inputs.input_ids, + "attention_mask": inputs.attention_mask, + "pixel_values_videos": getattr(inputs, "pixel_values_videos", None), + "video_grid_thw": getattr(inputs, "video_grid_thw", None), + "video_second_per_grid": getattr(inputs, "video_second_per_grid", None), + "input_features": getattr(inputs, "input_features", None), + "feature_attention_mask": getattr(inputs, "feature_attention_mask", None), + "messages": messages, + } + else: + # Text-only processing + messages = [{"role": "user", "content": prompt}] + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = processor(text=text, return_tensors="pt") + return { + "input_ids": inputs.input_ids, + "attention_mask": inputs.attention_mask, + "pixel_values_videos": None, + "video_grid_thw": None, + 
"video_second_per_grid": None, + "input_features": None, + "feature_attention_mask": None, + "messages": messages, + } + + +def main(args) -> None: + """Main function for omni-language generation from HuggingFace models. + + Loads a Qwen3-Omni model either from HuggingFace (with optional conversion to Megatron) + or directly from a Megatron checkpoint, then performs greedy generation + using the provided prompt and optional video input. + + Args: + args: Parsed command line arguments containing model paths, prompt, + video path, parallelism settings, and generation parameters + """ + tp = args.tp + pp = args.pp + ep = args.ep + etp = args.etp + + # Choose loading method based on arguments + if args.megatron_model_path: + # Load from Megatron checkpoint + print_rank_0(f"Loading Megatron model from: {args.megatron_model_path}") + + bridge = AutoBridge.from_hf_pretrained(args.hf_model_path) + + # Initialize model parallel before loading + model_provider = bridge.to_megatron_provider(load_weights=False) + model_provider.tensor_model_parallel_size = tp + model_provider.pipeline_model_parallel_size = pp + model_provider.expert_model_parallel_size = ep + model_provider.expert_tensor_parallel_size = etp + model_provider.pipeline_dtype = torch.bfloat16 + model_provider.finalize() + model_provider.initialize_model_parallel(seed=0) + + # Load the Megatron model directly + model = bridge.load_megatron_model( + args.megatron_model_path, + mp_overrides={ + "tensor_model_parallel_size": tp, + "pipeline_model_parallel_size": pp, + "expert_model_parallel_size": ep, + "expert_tensor_parallel_size": etp, + "pipeline_dtype": torch.bfloat16, + }, + wrap_with_ddp=False, + ) + + else: + # Load from HuggingFace and convert to Megatron + print_rank_0(f"Loading HuggingFace model from: {args.hf_model_path}") + bridge = AutoBridge.from_hf_pretrained(args.hf_model_path) + model_provider = bridge.to_megatron_provider(load_weights=True) + model_provider.tensor_model_parallel_size = tp + 
model_provider.pipeline_model_parallel_size = pp + model_provider.expert_model_parallel_size = ep + model_provider.expert_tensor_parallel_size = etp + model_provider.pipeline_dtype = torch.bfloat16 + model_provider.finalize() + model_provider.initialize_model_parallel(seed=0) + model = model_provider.provide_distributed_model(wrap_with_ddp=False) + + model = [m.cuda() for m in model] + for m in model: + m.eval() + + # Set grad_scale_func to None on the model's config for inference + for m in model: + if hasattr(m, "config"): + m.config.grad_scale_func = None + + # Initialize tokenizer and processor + tokenizer = AutoTokenizer.from_pretrained( + args.hf_model_path, + trust_remote_code=is_safe_repo( + trust_remote_code=args.trust_remote_code, + hf_path=args.hf_model_path, + ), + ) + processor = AutoProcessor.from_pretrained( + args.hf_model_path, + trust_remote_code=is_safe_repo( + trust_remote_code=args.trust_remote_code, + hf_path=args.hf_model_path, + ), + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Determine video path (URL or file) + video_path = args.video_url or args.video_path + + # Process inputs (text and video/audio if provided) + prompt = args.prompt + processed = process_omni_inputs(processor, video_path, prompt, args.use_audio_in_video) + + input_ids = processed["input_ids"] + attention_mask = processed["attention_mask"] + pixel_values_videos = processed["pixel_values_videos"] + video_grid_thw = processed["video_grid_thw"] + video_second_per_grid = processed["video_second_per_grid"] + input_features = processed["input_features"] + feature_attention_mask = processed["feature_attention_mask"] + + # Move to GPU + input_ids = input_ids.cuda() + attention_mask = attention_mask.cuda() + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.cuda() + if video_grid_thw is not None: + video_grid_thw = video_grid_thw.cuda() + if input_features is not None: + input_features = input_features.cuda() 
+ if feature_attention_mask is not None: + feature_attention_mask = feature_attention_mask.cuda() + + generated_ids = input_ids.clone() + + stop_tokens = [tokenizer.eos_token_id] + + use_audio_in_video = args.use_audio_in_video if video_path else None + + # Greedy generation loop + for step in range(args.max_new_tokens): + with torch.no_grad(): + print_rank_0(f"Generation step {step}") + + fwd_bwd_function = get_forward_backward_func() + + # Pass all multimodal inputs for every step + iterator = SingleBatchIterator( + input_ids, + attention_mask, + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + video_second_per_grid=video_second_per_grid, + input_features=input_features, + feature_attention_mask=feature_attention_mask, + use_audio_in_video=use_audio_in_video, + ) + + output = fwd_bwd_function( + forward_step_func=omni_forward_step, + data_iterator=iterator, + model=model, + num_microbatches=1, + forward_only=True, + seq_length=input_ids.size(1), + micro_batch_size=1, + collect_non_loss_data=True, + ) + if isinstance(output, list) and len(output) > 0: + output = output[0] + + if parallel_state.is_pipeline_last_stage(): + world_size = parallel_state.get_tensor_model_parallel_world_size() + gathered_tensors = [torch.zeros_like(output) for _ in range(world_size)] + # All-gather operation + dist.all_gather(gathered_tensors, output, group=parallel_state.get_tensor_model_parallel_group()) + # Concatenate along last dimension (dim=2) + output = torch.cat(gathered_tensors, dim=2) + next_token_ids = torch.argmax(output[:, -1], dim=-1, keepdim=True) + + # Debug: print token information + if step < 5: # Only for first few iterations + print_rank_0(f"Step {step}: output shape={output.shape}, var={output.var():.4f}") + logits = output[0, -1, :] + top5_vals, top5_ids = torch.topk(logits, 5) + top5_tokens = [tokenizer.decode([idx]) for idx in top5_ids] + print_rank_0(f"Top 5: {list(zip(top5_tokens, top5_vals.tolist()))}") + print_rank_0( + f"Selected: 
'{tokenizer.decode([next_token_ids.item()])}' (id={next_token_ids.item()})" + ) + else: + next_token_ids = torch.ones((1, 1), device=generated_ids.device, dtype=generated_ids.dtype) + + torch.distributed.broadcast(next_token_ids, get_last_rank()) + generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1) + + input_ids = generated_ids + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + + # If the generated token is the end of sequence token, stop generating + if next_token_ids.item() in stop_tokens: + break + + # Decode the generated sequence + generated_text = tokenizer.decode(list(generated_ids[0]), skip_special_tokens=True) + print_rank_0("======== GENERATED TEXT OUTPUT ========") + if video_path: + print_rank_0(f"Video: {video_path}") + print_rank_0(f"Use audio in video: {args.use_audio_in_video}") + print_rank_0(f"Prompt: {prompt}") + print_rank_0(f"Generated: {generated_text}") + print_rank_0("=======================================") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Omni-Language Generation from HuggingFace Qwen3-Omni Models") + parser.add_argument( + "--hf_model_path", + type=str, + required=True, + help="Path to the HuggingFace Qwen3-Omni model.", + ) + parser.add_argument( + "--prompt", + type=str, + default="What was the first sentence the boy said when he met the girl?", + help="Input prompt for omni-language generation.", + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=50, + help="Maximum number of new tokens to generate.", + ) + parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism size") + parser.add_argument("--pp", type=int, default=1, help="Pipeline parallelism size") + parser.add_argument("--ep", type=int, default=1, help="Expert parallelism size") + parser.add_argument("--etp", type=int, default=1, help="Expert tensor parallelism size") + parser.add_argument("--megatron_model_path", type=str, default=None, help="Path to the Megatron 
model checkpoint") + parser.add_argument( + "--video_path", + type=str, + default=None, + help="Local path to the video file (optional).", + ) + parser.add_argument( + "--video_url", + type=str, + default=None, + help="URL to the video file (optional).", + ) + parser.add_argument( + "--use_audio_in_video", + action="store_true", + help="Whether to use audio track from the video for understanding.", + ) + parser.add_argument("--trust_remote_code", action="store_true", help="if trust_remote_code") + args = parser.parse_args() + + main(args) + + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() diff --git a/src/megatron/bridge/models/__init__.py b/src/megatron/bridge/models/__init__.py index 8cb0e8ffaa..6c5dfa4281 100644 --- a/src/megatron/bridge/models/__init__.py +++ b/src/megatron/bridge/models/__init__.py @@ -178,6 +178,11 @@ Qwen25ModelProvider72B, Qwen25ModelProvider500M, ) +from megatron.bridge.models.qwen_omni import ( + Qwen25OmniBridge, + Qwen25OmniModel, + Qwen25OmniModelProvider, +) from megatron.bridge.models.qwen_vl import ( Qwen25VLBridge, Qwen25VLModel, @@ -350,6 +355,10 @@ "NemotronVLBridge", "NemotronNano12Bv2Provider", "NemotronNano12Bv2VLModelProvider", + # Omni Models + "Qwen25OmniModel", + "Qwen25OmniBridge", + "Qwen25OmniModelProvider", "SarvamMLABridge", "SarvamMoEBridge", ] diff --git a/src/megatron/bridge/models/conversion/auto_bridge.py b/src/megatron/bridge/models/conversion/auto_bridge.py index 3e0ac4e02a..34b7923a7d 100644 --- a/src/megatron/bridge/models/conversion/auto_bridge.py +++ b/src/megatron/bridge/models/conversion/auto_bridge.py @@ -46,8 +46,17 @@ "ForCausalLM", "ForConditionalGeneration", "NemotronH_Nano_VL_V2", + "Qwen2_5OmniModel", ) +# Mapping from non-standard HF architecture names to their actual transformers class names. 
+# Some HF model configs report architecture names that don't follow the standard +# 'ForCausalLM'/'ForConditionalGeneration' convention and don't directly map to a +# transformers class. This dict resolves those aliases. +HF_ARCHITECTURE_ALIASES: dict[str, str] = { + "Qwen2_5OmniModel": "Qwen2_5OmniForConditionalGeneration", +} + # Preformatted display string for error/help messages SUPPORTED_HF_ARCHITECTURES_DISPLAY = " or ".join(f"'{s}'" for s in SUPPORTED_HF_ARCHITECTURES) @@ -1111,11 +1120,14 @@ def _causal_lm_architecture(self): # For auto_map models, return the class name as a string return cls_name + # Resolve non-standard architecture names via alias mapping + resolved_arch = HF_ARCHITECTURE_ALIASES.get(causal_lm_arch, causal_lm_arch) + try: - return getattr(transformers, causal_lm_arch) + return getattr(transformers, resolved_arch) except AttributeError: raise ValueError( - f"\nāœ— Architecture class '{causal_lm_arch}' not found in transformers\n\n" + f"\nāœ— Architecture class '{resolved_arch}' not found in transformers\n\n" f"This could mean:\n" f"1. The model requires a newer version of transformers\n" f"2. 
The model uses a custom modeling file not in the standard library\n" @@ -1152,8 +1164,10 @@ def _validate_config(cls, config: PretrainedConfig, path: str | None = None) -> # For auto_map models, use class-name string arch_key = arch_name else: + # Resolve non-standard architecture names via alias mapping + resolved_arch = HF_ARCHITECTURE_ALIASES.get(architecture, architecture) try: - arch_class = getattr(transformers, architecture) + arch_class = getattr(transformers, resolved_arch) arch_key = arch_class except AttributeError: # Fall back to name-based registration diff --git a/src/megatron/bridge/models/qwen_omni/__init__.py b/src/megatron/bridge/models/qwen_omni/__init__.py new file mode 100644 index 0000000000..4c051916e2 --- /dev/null +++ b/src/megatron/bridge/models/qwen_omni/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from megatron.bridge.models.qwen_omni.modeling_qwen25_omni.model import Qwen25OmniModel +from megatron.bridge.models.qwen_omni.qwen25_omni_bridge import Qwen25OmniBridge +from megatron.bridge.models.qwen_omni.qwen25_omni_provider import Qwen25OmniModelProvider + + +__all__ = [ + "Qwen25OmniBridge", + "Qwen25OmniModel", + "Qwen25OmniModelProvider", +] diff --git a/src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/__init__.py b/src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/model.py b/src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/model.py new file mode 100644 index 0000000000..87d4278280 --- /dev/null +++ b/src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/model.py @@ -0,0 +1,141 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core import InferenceParams +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec +from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniTalkerConfig, + Qwen2_5OmniThinkerConfig, + Qwen2_5OmniToken2WavConfig, +) + +from megatron.bridge.models.qwen_omni.modeling_qwen25_omni.thinker_model import Qwen25OmniThinkerModel +from megatron.bridge.models.qwen_omni.modeling_qwen25_omni.transformer_config import Qwen25OmniTransformerConfig + + +class Qwen25OmniModel(MegatronModule): + """Qwen2.5 Omni Model. + + Top-level wrapper that delegates to Qwen25OmniThinkerModel. + Same pattern as Qwen3OmniMoeModel but simpler (no deepstack, dense LLM). 
+ """ + + def __init__( + self, + language_transformer_config: Qwen25OmniTransformerConfig, + language_transformer_layer_spec: ModuleSpec, + thinker_transformer_config: Qwen2_5OmniThinkerConfig, + talker_transformer_config: Qwen2_5OmniTalkerConfig | None = None, + token2wav_transformer_config: Qwen2_5OmniToken2WavConfig | None = None, + parallel_output: bool = True, + pre_process: bool = True, + post_process: bool = True, + add_encoder: bool = True, + add_decoder: bool = True, + pg_collection: ProcessGroupCollection | None = None, + ) -> None: + super().__init__(config=language_transformer_config) + + self.thinker = Qwen25OmniThinkerModel( + language_transformer_config, + language_transformer_layer_spec, + thinker_transformer_config, + parallel_output, + pre_process, + post_process, + add_encoder, + add_decoder, + pg_collection, + ) + + def shared_embedding_or_output_weight(self): + """This is a convenience method to surface the language model's word embeddings, which is + necessary for `finalize_model_grads._allreduce_word_embedding_grads`.""" + return self.thinker.shared_embedding_or_output_weight() + + def set_input_tensor(self, input_tensor) -> None: + return self.thinker.set_input_tensor(input_tensor) + + def freeze( + self, + freeze_language_model: bool = False, + freeze_vision_model: bool = False, + freeze_audio_model: bool = False, + ): + """Freeze model modules. + + Args: + freeze_language_model (bool): Freeze the language model module. + freeze_vision_model (bool): Freeze the vision model module. + freeze_audio_model (bool): Freeze the audio model module. 
+ """ + return self.thinker.freeze( + freeze_language_model, + freeze_vision_model, + freeze_audio_model, + ) + + def forward( + self, + input_ids: torch.Tensor, + input_features=None, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + loss_mask: torch.Tensor | None = None, + inference_params: InferenceParams | None = None, + packed_seq_params: PackedSeqParams | None = None, + extra_block_kwargs: dict | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.Tensor | None = None, + image_grid_thw: torch.Tensor | None = None, + video_grid_thw: torch.Tensor | None = None, + image_input_mask: torch.Tensor | None = None, + video_input_mask: torch.Tensor | None = None, + feature_attention_mask=None, + audio_feature_lengths=None, + cp_img_num: list[int] | None = None, + images_padded: list[bool] | None = None, + use_audio_in_video=None, + video_second_per_grid=None, + **kwargs, + ) -> torch.Tensor: + return self.thinker( + input_ids=input_ids, + input_features=input_features, + position_ids=position_ids, + attention_mask=attention_mask, + labels=labels, + loss_mask=loss_mask, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + extra_block_kwargs=extra_block_kwargs, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_input_mask=image_input_mask, + video_input_mask=video_input_mask, + feature_attention_mask=feature_attention_mask, + audio_feature_lengths=audio_feature_lengths, + cp_img_num=cp_img_num, + images_padded=images_padded, + use_audio_in_video=use_audio_in_video, + video_second_per_grid=video_second_per_grid, + **kwargs, + ) diff --git a/src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/rope.py b/src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/rope.py new file mode 100644 index 0000000000..5ea0c6f376 --- /dev/null +++ 
b/src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/rope.py @@ -0,0 +1,312 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def _get_feat_extract_output_lengths(input_lengths): + """ + Computes the output length of the convolutional layers and the audio encoder + for Qwen2.5-Omni. + + Formula: feat = (input_lengths - 1) // 2 + 1, output = (feat - 2) // 2 + 1 + """ + feat_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (feat_lengths - 2) // 2 + 1 + return output_lengths + + +def get_llm_pos_ids_for_vision( + start_idx: int, + vision_idx: int, + spatial_merge_size: int, + t_index: list[torch.Tensor], + grid_hs: list[torch.Tensor], + grid_ws: list[torch.Tensor], +): + """Get LLM position IDs for vision tokens (3D: temporal, height, width).""" + llm_pos_ids_list = [] + llm_grid_h = grid_hs[vision_idx] // spatial_merge_size + llm_grid_w = grid_ws[vision_idx] // spatial_merge_size + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten() + t_index = torch.Tensor(t_index).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten().long() + _llm_pos_ids = torch.stack([t_index, h_index, w_index]) + llm_pos_ids_list.append(_llm_pos_ids + start_idx) + llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) + return llm_pos_ids + + +def 
def get_chunked_index(token_indices: torch.Tensor, tokens_per_chunk: int, remove_index: int) -> list[tuple[int, int]]:
    """
    Splits token index list into chunks based on token value ranges.

    Given a list of token indices, returns a list of (start, end) index tuples representing
    slices of the list where the token values fall within successive ranges of tokens_per_chunk.
    """

    def _iter():
        # Walk the (assumed monotonically non-decreasing — TODO confirm) indices once,
        # emitting a chunk boundary each time a value crosses the next multiple of
        # tokens_per_chunk (relative to remove_index).
        i, start_idx = 0, 0
        current_chunk = 1
        while i < len(token_indices):
            if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk:
                yield (start_idx, i)
                start_idx = i
                current_chunk += 1
            i += 1
        # Final chunk always covers the remaining tail (may be the whole list).
        yield (start_idx, len(token_indices))

    return list(_iter())


def get_rope_index(
    spatial_merge_size: int,
    image_token_id: int,
    video_token_id: int,
    audio_token_id: int,
    vision_start_token_id: int,
    audio_start_token_id: int,
    position_id_per_seconds: int,
    seconds_per_chunk: int = 2,
    input_ids: torch.LongTensor | None = None,
    image_grid_thw: torch.LongTensor | None = None,
    video_grid_thw: torch.LongTensor | None = None,
    attention_mask: torch.Tensor | None = None,
    use_audio_in_video: bool = False,
    audio_seqlens: torch.LongTensor | None = None,
    second_per_grids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Calculate the 3D rope index based on image and video's temporal, height and width in LLM.

    Ported from HF Qwen2_5OmniThinkerForConditionalGeneration.get_rope_index as a standalone function.

    Key differences from Qwen3 Omni MoE rope:
    - Audio output length: ((audio_seqlens - 1) // 2 + 1 - 2) // 2 + 1
    - Token scanning: searches for image_token_id/video_token_id/audio_token_id directly
    - Has seconds_per_chunk for audio-in-video interleaving
    - Uses get_chunked_index for audio-in-video chunk interleaving

    Returns:
        (position_ids, mrope_position_deltas): position_ids has shape
        [3, batch, seq] (temporal/height/width planes); mrope_position_deltas
        has shape [batch, 1] and is (max position + 1 - sequence length) per sample.
    """
    mrope_position_deltas = []
    # Multimodal path: only taken when there is at least one image or video grid.
    # NOTE(review): a text+audio-only batch (no image/video grids) falls through to
    # the plain-text branch below — presumably matching the HF reference; confirm.
    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        total_input_ids = input_ids
        if attention_mask is not None:
            attention_mask = attention_mask == 1
        position_ids = torch.ones(
            3,
            input_ids.shape[0],
            input_ids.shape[1],
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        image_idx, video_idx, audio_idx = 0, 0, 0
        for i, batch_input_ids in enumerate(total_input_ids):
            if attention_mask is not None:
                batch_input_ids = batch_input_ids[attention_mask[i]]
            image_nums, video_nums, audio_nums = 0, 0, 0
            vision_start_indices = torch.argwhere(batch_input_ids == vision_start_token_id).squeeze(1)
            vision_tokens = batch_input_ids[vision_start_indices + 1]
            audio_nums = torch.sum(batch_input_ids == audio_start_token_id)
            image_nums = (vision_tokens == image_token_id).sum()
            # When audio is embedded in video, a video segment starts with an audio
            # BOS right after vision start, so videos are counted via that token.
            video_nums = (
                (vision_tokens == audio_start_token_id).sum()
                if use_audio_in_video
                else (vision_tokens == video_token_id).sum()
            )
            input_tokens = batch_input_ids.tolist()
            llm_pos_ids_list: list = []
            st = 0
            remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums
            # Audio-in-video fuses each (video, audio) pair into one segment, so the
            # per-sample segment count drops the separate video count.
            multimodal_nums = image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums
            for _ in range(multimodal_nums):
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                # Find the next occurrence of each modality token; sentinel past-the-end
                # value means "none remaining".
                if image_token_id in input_tokens and remain_images > 0:
                    ed_image = input_tokens.index(image_token_id, st)
                else:
                    ed_image = len(input_tokens) + 1
                if video_token_id in input_tokens and remain_videos > 0:
                    ed_video = input_tokens.index(video_token_id, st)
                else:
                    ed_video = len(input_tokens) + 1
                if audio_token_id in input_tokens and remain_audios > 0:
                    ed_audio = input_tokens.index(audio_token_id, st)
                else:
                    ed_audio = len(input_tokens) + 1
                min_ed = min(ed_image, ed_video, ed_audio)

                # Audio Only
                if min_ed == ed_audio:
                    # Text before the segment, minus the 1-token audio BOS.
                    text_len = min_ed - st - 1
                    if text_len != 0:
                        st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    bos_len = 1
                    llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
                    llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
                    llm_pos_ids_list.append(llm_pos_ids)

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    eos_len = 1
                    llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)

                    st += text_len + bos_len + audio_len + eos_len
                    audio_idx += 1
                    remain_audios -= 1

                # Image Only
                elif min_ed == ed_image:
                    text_len = min_ed - st - 1
                    if text_len != 0:
                        st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    bos_len = 1
                    llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    grid_t = image_grid_thw[image_idx][0]
                    grid_hs = image_grid_thw[:, 1]
                    grid_ws = image_grid_thw[:, 2]
                    # Images are treated as 1-second frames (hence the explicit * 1).
                    t_index = (torch.arange(grid_t) * 1 * position_id_per_seconds).long()
                    llm_pos_ids = get_llm_pos_ids_for_vision(
                        st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                    )
                    image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
                    llm_pos_ids_list.append(llm_pos_ids)

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    eos_len = 1
                    llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)

                    st += text_len + bos_len + image_len + eos_len
                    image_idx += 1
                    remain_images -= 1

                # Video Only (no audio in video)
                elif min_ed == ed_video and not use_audio_in_video:
                    text_len = min_ed - st - 1
                    if text_len != 0:
                        st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    bos_len = 1
                    llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    grid_t = video_grid_thw[video_idx][0]
                    grid_hs = video_grid_thw[:, 1]
                    grid_ws = video_grid_thw[:, 2]
                    # Temporal index scales frame index by real seconds-per-grid.
                    t_index = (
                        torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds
                    ).long()
                    llm_pos_ids = get_llm_pos_ids_for_vision(
                        st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                    )
                    video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
                    llm_pos_ids_list.append(llm_pos_ids)

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    eos_len = 1
                    llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)

                    st += text_len + bos_len + video_len + eos_len
                    video_idx += 1
                    remain_videos -= 1

                # Audio in Video
                elif min_ed == ed_video and use_audio_in_video:
                    # Two BOS tokens precede the fused segment (vision BOS + audio BOS),
                    # hence "- 2" here and the doubled BOS/EOS appends below.
                    text_len = min_ed - st - 2
                    if text_len != 0:
                        st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    bos_len = 1
                    llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
                    llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
                    audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
                    grid_t = video_grid_thw[video_idx][0]
                    grid_hs = video_grid_thw[:, 1]
                    grid_ws = video_grid_thw[:, 2]

                    t_index = (
                        torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds
                    ).long()
                    video_llm_pos_ids = get_llm_pos_ids_for_vision(
                        st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                    )

                    # Interleave video and audio positions chunk-by-chunk along the
                    # temporal axis so the two modalities stay time-aligned.
                    t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk)
                    video_chunk_indexes = get_chunked_index(video_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
                    audio_chunk_indexes = get_chunked_index(audio_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
                    sub_len = 0
                    for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
                        video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None
                        audio_chunk_index = audio_chunk_indexes[j] if j < len(audio_chunk_indexes) else None
                        if video_chunk_index is not None:
                            sub_len += video_chunk_index[1] - video_chunk_index[0]
                            llm_pos_ids_list.append(video_llm_pos_ids[:, video_chunk_index[0] : video_chunk_index[1]])
                        if audio_chunk_index is not None:
                            sub_len += audio_chunk_index[1] - audio_chunk_index[0]
                            llm_pos_ids_list.append(audio_llm_pos_ids[:, audio_chunk_index[0] : audio_chunk_index[1]])
                    video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    eos_len = 1
                    llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
                    llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)

                    st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2

                    audio_idx += 1
                    video_idx += 1
                    remain_videos -= 1
                    remain_audios -= 1

            # Trailing text after the last multimodal segment.
            if st < len(input_tokens):
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)

            if attention_mask is not None:
                position_ids[..., i, attention_mask[i]] = llm_positions.to(position_ids.device)
            else:
                position_ids[..., i, :] = llm_positions.to(position_ids.device)
            mrope_position_deltas.append(llm_positions.max() + 1 - len(batch_input_ids))
        mrope_position_deltas = torch.tensor(mrope_position_deltas).unsqueeze(1).to(device=total_input_ids.device)

        return position_ids, mrope_position_deltas
    else:
        # Text-only fallback: plain cumulative positions replicated across the
        # three mrope planes; masked positions are set to 1 (as in the HF port).
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        attention_mask = attention_mask.to(input_ids.device)
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
        max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
        mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True)

        return position_ids, mrope_position_deltas
diff --git a/src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/thinker_model.py b/src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/thinker_model.py
new file mode 100644
index 0000000000..c52b1396a5
--- /dev/null
+++ b/src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/thinker_model.py
@@ -0,0 +1,352 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
    Qwen2_5OmniThinkerConfig as Qwen2_5OmniThinkerConfigHF,
)
from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import (
    Qwen2_5OmniAudioEncoder as Qwen2_5OmniAudioEncoderHF,
)
from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import (
    Qwen2_5OmniVisionEncoder as Qwen2_5OmniVisionEncoderHF,
)

from megatron.bridge.models.qwen_omni.modeling_qwen25_omni.rope import get_rope_index
from megatron.bridge.models.qwen_omni.modeling_qwen25_omni.transformer_config import Qwen25OmniTransformerConfig
from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.attention import Qwen3VLSelfAttention
from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.text_model import Qwen3VLGPTModel
from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.utils import (
    split_data_cp_rank,
)
from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync


class Qwen25OmniThinkerModel(MegatronModule):
    """Qwen2.5 Omni Thinker Model.

    Key differences from Qwen3OmniMoeThinkerModel:
    - Uses HF vision encoder (Qwen2_5OmniVisionEncoder) directly, not Megatron-native
    - Uses HF audio encoder (Qwen2_5OmniAudioEncoder) directly
    - No deepstack visual embeddings
    - Vision embeddings inserted only at input level
    - Dense LLM (Qwen2 architecture), not MoE

    Args:
        language_transformer_config: Megatron transformer config for the dense LLM;
            also carries the multimodal token IDs and rope timing parameters.
        language_transformer_layer_spec: layer spec whose self-attention module is
            overridden to Qwen3VLSelfAttention in __init__.
        thinker_transformer_config: HF thinker config supplying vision_config and
            audio_config for the HF-native encoders.
        parallel_output: forwarded to the language model.
        pre_process / post_process: pipeline-stage flags; the HF vision/audio
            encoders are only instantiated on pre_process ranks.
        pg_collection: process-group collection; must expose cp/tp/pp and embd.
    """

    def __init__(
        self,
        language_transformer_config: Qwen25OmniTransformerConfig,
        language_transformer_layer_spec: ModuleSpec,
        thinker_transformer_config: Qwen2_5OmniThinkerConfigHF,
        parallel_output: bool = True,
        pre_process: bool = True,
        post_process: bool = True,
        add_encoder: bool = True,
        add_decoder: bool = True,
        pg_collection: ProcessGroupCollection | None = None,
    ) -> None:
        super().__init__(config=language_transformer_config)

        # Swap in the Qwen3-VL self-attention implementation (mrope-aware).
        # NOTE(review): this mutates the shared layer-spec object in place.
        language_transformer_layer_spec.submodules.self_attention.module = Qwen3VLSelfAttention

        self.pre_process = pre_process
        self.post_process = post_process
        self.add_encoder = add_encoder
        self.add_decoder = add_decoder

        self.encoder_hidden_state = None
        self.visual = None
        self.audio_model = None
        self.language_model = None
        # Multimodal special-token IDs used to locate embedding insertion points.
        self.image_token_id = language_transformer_config.image_token_id
        self.video_token_id = language_transformer_config.video_token_id
        self.audio_token_id = language_transformer_config.audio_token_id
        self.vision_start_token_id = language_transformer_config.vision_start_token_id
        self.audio_start_token_id = language_transformer_config.audio_start_token_id
        self.position_id_per_seconds = language_transformer_config.position_id_per_seconds
        self.seconds_per_chunk = language_transformer_config.seconds_per_chunk

        self.square_merge_size = thinker_transformer_config.vision_config.spatial_merge_size**2

        self.share_embeddings_and_output_weights = False
        self.pg_collection = pg_collection
        self.cp_group = pg_collection.cp
        self.tp_group = pg_collection.tp
        self.pp_group = pg_collection.pp
        assert hasattr(self.pg_collection, "embd"), (
            "pg_collection must have a embd. In previous version, it used default "
            "`parallel_state.default_embedding_ranks` to create the process group."
            "If you are using the default process group, please use"
            "`parallel_state.get_embedding_group()` "
            "If you don't need embd_group, you need to explicitly set it to None."
        )
        self.embd_group = pg_collection.embd
        self.vp_stage = None
        self.vp_size = self.config.virtual_pipeline_model_parallel_size

        if self.pre_process:
            # Use HF vision encoder directly (ReplicatedMapping in bridge)
            self.visual = Qwen2_5OmniVisionEncoderHF._from_config(thinker_transformer_config.vision_config)
            hook_hf_module_setattr_for_tp_grad_sync(self.visual)

            # Use HF audio encoder directly (ReplicatedMapping in bridge)
            self.audio_model = Qwen2_5OmniAudioEncoderHF._from_config(thinker_transformer_config.audio_config)
            hook_hf_module_setattr_for_tp_grad_sync(self.audio_model)

        self.language_model = Qwen3VLGPTModel(
            config=language_transformer_config,
            transformer_layer_spec=language_transformer_layer_spec,
            vocab_size=language_transformer_config.vocab_size,
            max_sequence_length=language_transformer_config.language_max_sequence_length,
            parallel_output=parallel_output,
            position_embedding_type="mrope",
            rotary_percent=language_transformer_config.rotary_percent,
            pre_process=self.pre_process,
            post_process=self.post_process,
            rotary_base=language_transformer_config.rotary_base,
            fp16_lm_cross_entropy=language_transformer_config.fp16_lm_cross_entropy,
            share_embeddings_and_output_weights=language_transformer_config.share_embeddings_and_output_weights,
            scatter_embedding_sequence_parallel=False,
            pg_collection=pg_collection,
        )

        self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights

    def shared_embedding_or_output_weight(self):
        """This is a convenience method to surface the language model's word embeddings, which is
        necessary for `finalize_model_grads._allreduce_word_embedding_grads`."""
        if self.add_decoder:
            return self.language_model.shared_embedding_or_output_weight()
        return None

    def set_input_tensor(self, input_tensor) -> None:
        """Receive the previous pipeline stage's activation (Megatron PP contract)."""
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]
        assert len(input_tensor) == 1, "input_tensor should only be length 1 for Qwen25Omni"

        if self.pre_process:
            self.encoder_hidden_state = input_tensor[0]
        else:
            self.language_model.set_input_tensor(input_tensor[0])

    def freeze(
        self,
        freeze_language_model: bool = False,
        freeze_vision_model: bool = False,
        freeze_audio_model: bool = False,
    ):
        """Freeze model modules.

        Args:
            freeze_language_model (bool): Freeze the language model module.
            freeze_vision_model (bool): Freeze the vision model module.
            freeze_audio_model (bool): Freeze the audio model module.
        """
        modules = []

        if freeze_language_model and self.language_model is not None:
            modules.append(self.language_model)

        if freeze_vision_model and self.visual is not None:
            modules.append(self.visual)

        if freeze_audio_model and self.audio_model is not None:
            modules.append(self.audio_model)

        for module in modules:
            for param in module.parameters():
                param.requires_grad = False

    def get_audio_features(
        self,
        input_features: torch.FloatTensor,
        feature_attention_mask: torch.LongTensor | None = None,
        audio_feature_lengths: torch.LongTensor | None = None,
    ):
        """Run the HF audio encoder and return its last hidden state.

        One of feature_attention_mask or audio_feature_lengths must be given;
        when the mask is present it both derives the lengths and packs the
        valid frames of input_features.
        """
        if feature_attention_mask is not None:
            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
            # Pack valid frames: (batch, feat, time) -> flat (feat, total_valid_time).
            input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)

        if audio_feature_lengths is None:
            raise ValueError("Either feature_attention_mask or audio_feature_lengths must be provided")

        feature_lens = audio_feature_lengths
        audio_feat_lengths, audio_output_lengths = self.audio_model._get_feat_extract_output_lengths(feature_lens)

        audio_outputs = self.audio_model(
            input_features,
            feature_lens=feature_lens,
            aftercnn_lens=audio_feat_lengths,
        )

        return audio_outputs.last_hidden_state

    def forward(
        self,
        input_ids: torch.Tensor,
        input_features=None,
        position_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        loss_mask: torch.Tensor | None = None,
        inference_params: InferenceParams | None = None,
        packed_seq_params: PackedSeqParams | None = None,
        extra_block_kwargs: dict | None = None,
        pixel_values: torch.Tensor | None = None,
        pixel_values_videos: torch.Tensor | None = None,
        image_grid_thw: torch.Tensor | None = None,
        video_grid_thw: torch.Tensor | None = None,
        image_input_mask: torch.Tensor | None = None,
        video_input_mask: torch.Tensor | None = None,
        feature_attention_mask=None,
        audio_feature_lengths=None,
        cp_img_num: list[int] | None = None,
        images_padded: list[bool] | None = None,
        use_audio_in_video=None,
        video_second_per_grid=None,
        **kwargs,
    ) -> torch.Tensor:
        """Embed multimodal inputs, build mrope position IDs, and run the LLM.

        Returns whatever the language model returns (loss when labels are
        given, otherwise logits/hidden states, per Megatron GPT convention).
        """
        if inference_params is not None:
            raise NotImplementedError("inference is not supported")
        if packed_seq_params is not None:
            raise NotImplementedError("packed_seq_params is not supported")

        cp_rank = self.pg_collection.cp.rank()
        cp_size = self.pg_collection.cp.size()

        if self.pre_process:
            # Run HF vision encoder to get vision embeddings (no deepstack)
            vision_embeds = None
            vision_mask = None
            if pixel_values is not None or pixel_values_videos is not None:
                # Build vision mask from input_ids
                image_mask = input_ids == self.image_token_id
                video_mask = input_ids == self.video_token_id
                vision_mask = image_mask | video_mask

                # Process images through vision encoder
                if pixel_values is not None and image_grid_thw is not None:
                    image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                else:
                    image_embeds = None

                # Process videos through vision encoder
                if pixel_values_videos is not None and video_grid_thw is not None:
                    video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                else:
                    video_embeds = None

                # Combine image and video embeddings
                # (order matches the image-then-video order of vision_mask positions
                #  — NOTE(review): assumes images precede videos in the sequence; confirm)
                if image_embeds is not None and video_embeds is not None:
                    vision_embeds = torch.cat([image_embeds, video_embeds], dim=0)
                elif image_embeds is not None:
                    vision_embeds = image_embeds
                elif video_embeds is not None:
                    vision_embeds = video_embeds

            # Extract audio features
            audio_embeds = None
            if input_features is not None:
                audio_embeds = self.get_audio_features(
                    input_features,
                    feature_attention_mask=feature_attention_mask,
                    audio_feature_lengths=audio_feature_lengths,
                )
                audio_mask = input_ids == self.audio_token_id

            # Get text embeddings from language model
            combined_embeddings = self.language_model.embedding(
                input_ids=input_ids,
                position_ids=None,
            ).clone()  # [text_seq_len, b, h_language]

            # Replace vision/audio token positions with vision_embeds/audio_embeds
            if vision_embeds is not None or audio_embeds is not None:
                # Transpose to [b, s, h] so boolean masks over input_ids line up.
                combined_embeddings = combined_embeddings.transpose(0, 1).contiguous()
                if vision_embeds is not None:
                    combined_embeddings[vision_mask] = vision_embeds
                if audio_embeds is not None:
                    combined_embeddings[audio_mask] = audio_embeds
                combined_embeddings = combined_embeddings.transpose(0, 1).contiguous()

            # Context parallelism: shard the sequence dimension across CP ranks.
            if combined_embeddings is not None and cp_size > 1 and packed_seq_params is None:
                combined_embeddings = split_data_cp_rank(combined_embeddings, cp_size, 0, cp_rank)

            # Track SP padding amount for position_ids alignment
            sp_pad_len = 0
            if self.config.sequence_parallel:
                tp_size = self.pg_collection.tp.size()
                seq_len = combined_embeddings.shape[0]
                sp_pad_len = (tp_size - seq_len % tp_size) % tp_size
                if sp_pad_len > 0:
                    combined_embeddings = torch.nn.functional.pad(combined_embeddings, (0, 0, 0, 0, 0, sp_pad_len))
                combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings)
                combined_embeddings = combined_embeddings.contiguous()
        else:
            combined_embeddings = None
            sp_pad_len = 0

        # Compute audio feature lengths for rope computation
        if feature_attention_mask is not None:
            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
        else:
            audio_feature_lengths = None

        # Compute position IDs via get_rope_index if not provided
        if position_ids is None:
            position_ids, _ = get_rope_index(
                self.config.spatial_merge_size,
                self.image_token_id,
                self.video_token_id,
                self.audio_token_id,
                self.vision_start_token_id,
                self.audio_start_token_id,
                self.position_id_per_seconds,
                self.seconds_per_chunk,
                input_ids,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                attention_mask=attention_mask,
                use_audio_in_video=use_audio_in_video,
                audio_seqlens=audio_feature_lengths,
                second_per_grids=video_second_per_grid,
            )

        # Pad position_ids to match SP-padded embeddings so rotary_pos_emb
        # has the same sequence length as the all-gathered query/key tensors.
        if sp_pad_len > 0 and position_ids is not None:
            # position_ids shape: [3, batch, seq_len] → pad last dim
            position_ids = torch.nn.functional.pad(position_ids, (0, sp_pad_len), mode="replicate")

        # No deepstack for Qwen2.5 Omni
        output = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            decoder_input=combined_embeddings,
            labels=labels,
            loss_mask=loss_mask,
            inference_params=inference_params,
            packed_seq_params=packed_seq_params,
            visual_pos_masks=None,
            deepstack_visual_embeds=None,
            **(extra_block_kwargs or {}),
        )

        return output
diff --git a/src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/transformer_config.py b/src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/transformer_config.py
new file mode 100644
index 0000000000..49a3fe5528
--- /dev/null
+++ b/src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/transformer_config.py
@@ -0,0 +1,50 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from megatron.core.transformer.transformer_config import TransformerConfig


@dataclass
class Qwen25OmniTransformerConfig(TransformerConfig):
    """Configuration for Qwen2.5 Omni transformer with vision, audio, and language components."""

    # Language-model vocabulary and maximum decode sequence length.
    vocab_size: int = 152064
    language_max_sequence_length: int = 4096

    # Vision patchification parameters (Qwen2.5-Omni ViT uses 14px patches).
    patch_size: int = 14
    temporal_patch_size: int = 2
    in_channels: int = 3
    spatial_merge_size: int = 2

    apply_rotary_pos_emb_in_fp32: bool = False
    fp16_lm_cross_entropy: bool = False
    share_embeddings_and_output_weights: bool = False
    rotary_percent: float = 1.0
    rotary_base: float = 10000

    # Multimodal rope section for [temporal, height, width] dimensions
    mrope_section: list[int] = field(default_factory=lambda: [16, 24, 24])
    apply_rope_fusion: bool = False

    # Special-token IDs matching the Qwen2.5-Omni tokenizer.
    image_token_id: int = 151655
    video_token_id: int = 151656
    audio_token_id: int = 151646
    vision_start_token_id: int = 151652
    audio_start_token_id: int = 151647
    # Temporal rope granularity: 25 position IDs per second of media; audio/video
    # interleaving chunk size of 2 seconds.
    position_id_per_seconds: int = 25
    seconds_per_chunk: int = 2

    # Qwen2 architecture has no QK layernorm.
    qk_layernorm: bool = False
diff --git a/src/megatron/bridge/models/qwen_omni/qwen25_omni_bridge.py b/src/megatron/bridge/models/qwen_omni/qwen25_omni_bridge.py
new file mode 100644
index 0000000000..c2258cf003
--- /dev/null
+++ b/src/megatron/bridge/models/qwen_omni/qwen25_omni_bridge.py
@@ -0,0 +1,150 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from transformers import Qwen2_5OmniForConditionalGeneration

from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
from megatron.bridge.models.conversion.param_mapping import (
    AutoMapping,
    GatedMLPMapping,
    QKVMapping,
    ReplicatedMapping,
)
from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM
from megatron.bridge.models.qwen_omni.modeling_qwen25_omni.model import Qwen25OmniModel
from megatron.bridge.models.qwen_omni.qwen25_omni_provider import Qwen25OmniModelProvider


@MegatronModelBridge.register_bridge(source=Qwen2_5OmniForConditionalGeneration, target=Qwen25OmniModel)
class Qwen25OmniBridge(MegatronModelBridge):
    """
    Megatron Bridge for Qwen2.5-Omni Conditional Generation.

    Handles conversion between HuggingFace Qwen2_5OmniForConditionalGeneration
    and Megatron-Core Qwen25OmniModel formats.

    Key differences from Qwen3OmniMoeBridge:
    - Dense LLM (Qwen2), not MoE -> no router/expert mappings
    - QKV bias mappings (Qwen2 has attention bias)
    - No QK layernorm weight mappings
    - Vision: ReplicatedMapping for HF vision encoder (thinker.visual.**)
    - Audio: ReplicatedMapping for HF audio encoder (thinker.audio_model.** -> thinker.audio_tower.**)
    - LLM layer norms use mlp.linear_fc1.layer_norm_weight (not pre_mlp_layernorm)
    """

    def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Qwen25OmniModelProvider:
        """Create a Qwen25OmniModelProvider from a HuggingFace pretrained model."""
        hf_config = hf_pretrained.config
        thinker_config = hf_config.thinker_config
        talker_config = hf_config.talker_config
        token2wav_config = hf_config.token2wav_config
        text_config = thinker_config.text_config
        # Model dtype is read from the thinker config; falls back to fp32.
        model_dtype = self.dtype_from_hf(thinker_config, default=torch.float32)

        provider = Qwen25OmniModelProvider(
            thinker_config=thinker_config,
            talker_config=talker_config,
            token2wav_config=token2wav_config,
            num_layers=text_config.num_hidden_layers,
            hidden_size=text_config.hidden_size,
            ffn_hidden_size=text_config.intermediate_size,
            num_attention_heads=text_config.num_attention_heads,
            num_query_groups=text_config.num_key_value_heads,
            head_dim=getattr(text_config, "head_dim", text_config.hidden_size // text_config.num_attention_heads),
            init_method_std=text_config.initializer_range,
            layernorm_epsilon=text_config.rms_norm_eps,
            gated_linear_unit=True,
            make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(text_config.vocab_size),
            rotary_base=getattr(text_config, "rope_theta", 1000000),
            share_embeddings_and_output_weights=getattr(text_config, "tie_word_embeddings", False),
            vocab_size=text_config.vocab_size,
            seq_length=text_config.max_position_embeddings,
            fp16=(model_dtype == torch.float16),
            bf16=(model_dtype == torch.bfloat16),
            params_dtype=model_dtype,
            add_qkv_bias=True,  # Qwen2 always has QKV bias
            qk_layernorm=False,  # Qwen2 has no QK layernorm
            # Token IDs from thinker config
            image_token_id=getattr(thinker_config, "image_token_index", 151655),
            video_token_id=getattr(thinker_config, "video_token_index", 151656),
            audio_token_id=getattr(thinker_config, "audio_token_index", 151646),
            vision_start_token_id=getattr(thinker_config, "vision_start_token_id", 151652),
            audio_start_token_id=getattr(thinker_config, "audio_start_token_id", 151647),
            audio_end_token_id=getattr(thinker_config, "audio_end_token_id", 151648),
            mrope_section=(getattr(text_config, "rope_scaling", None) or {}).get("mrope_section", [16, 24, 24]),
            position_id_per_seconds=getattr(thinker_config, "position_id_per_seconds", 25),
            seconds_per_chunk=getattr(thinker_config, "seconds_per_chunk", 2),
        )
        return provider

    def mapping_registry(self) -> MegatronMappingRegistry:
        """Return MegatronMappingRegistry containing parameter mappings for dense Qwen2.5 Omni models."""
        # LLM parameter mappings (same pattern as Qwen25VL bridge but prefixed with thinker.)
        param_mappings = {
            # Embeddings and output layers
            "thinker.language_model.embedding.word_embeddings.weight": "thinker.model.embed_tokens.weight",
            "thinker.language_model.output_layer.weight": "thinker.lm_head.weight",
            "thinker.language_model.decoder.final_layernorm.weight": "thinker.model.norm.weight",
            # Layer normalization
            "thinker.language_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "thinker.model.layers.*.input_layernorm.weight",
            "thinker.language_model.decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "thinker.model.layers.*.post_attention_layernorm.weight",
            # Attention output projection
            "thinker.language_model.decoder.layers.*.self_attention.linear_proj.weight": "thinker.model.layers.*.self_attn.o_proj.weight",
            # MLP down projection
            "thinker.language_model.decoder.layers.*.mlp.linear_fc2.weight": "thinker.model.layers.*.mlp.down_proj.weight",
        }

        mapping_list = []
        for megatron_param, hf_param in param_mappings.items():
            mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param))

        mapping_list.extend(
            [
                # Vision: ReplicatedMapping (HF vision encoder used directly)
                ReplicatedMapping(
                    megatron_param="thinker.visual.**",
                    hf_param="thinker.visual.**",
                ),
                # Audio: ReplicatedMapping (HF audio encoder used directly)
                # HF uses thinker.audio_tower, Megatron uses thinker.audio_model
                ReplicatedMapping(
                    megatron_param="thinker.audio_model.**",
                    hf_param="thinker.audio_tower.**",
                ),
                # QKV weight: Combine separate Q, K, V weights into single QKV matrix
                QKVMapping(
                    megatron_param="thinker.language_model.decoder.layers.*.self_attention.linear_qkv.weight",
                    q="thinker.model.layers.*.self_attn.q_proj.weight",
                    k="thinker.model.layers.*.self_attn.k_proj.weight",
                    v="thinker.model.layers.*.self_attn.v_proj.weight",
                ),
                # QKV bias: Combine separate Q, K, V biases into single QKV bias (Qwen2 has attention bias)
                QKVMapping(
                    megatron_param="thinker.language_model.decoder.layers.*.self_attention.linear_qkv.bias",
                    q="thinker.model.layers.*.self_attn.q_proj.bias",
                    k="thinker.model.layers.*.self_attn.k_proj.bias",
                    v="thinker.model.layers.*.self_attn.v_proj.bias",
                ),
                # Gated MLP: Combine gate and up projection matrices into single FC1 matrix
                GatedMLPMapping(
                    megatron_param="thinker.language_model.decoder.layers.*.mlp.linear_fc1.weight",
                    gate="thinker.model.layers.*.mlp.gate_proj.weight",
                    up="thinker.model.layers.*.mlp.up_proj.weight",
                ),
            ]
        )

        return MegatronMappingRegistry(*mapping_list)
diff --git a/src/megatron/bridge/models/qwen_omni/qwen25_omni_provider.py b/src/megatron/bridge/models/qwen_omni/qwen25_omni_provider.py
new file mode 100644
index 0000000000..41b8f71604
--- /dev/null
+++ b/src/megatron/bridge/models/qwen_omni/qwen25_omni_provider.py
@@ -0,0 +1,141 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Qwen2.5 Omni Model Provider configurations for Megatron-Core.

This module provides configuration classes for Qwen2.5 Omni multimodal models
(audio+vision+text), compatible with HuggingFace's Qwen2.5-Omni model configurations.
Reference: https://huggingface.co/Qwen/Qwen2.5-Omni-7B
"""

from dataclasses import dataclass, field

from megatron.core.models.gpt import GPTModel as MCoreGPTModel
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
    Qwen2_5OmniTalkerConfig,
    Qwen2_5OmniThinkerConfig,
    Qwen2_5OmniToken2WavConfig,
)

from megatron.bridge.models import Qwen2ModelProvider
from megatron.bridge.models.qwen_omni.modeling_qwen25_omni.model import Qwen25OmniModel


@dataclass
class Qwen25OmniModelProvider(Qwen2ModelProvider):
    """
    Base model provider for Qwen2.5 Omni Models.
    Inherits language model configuration from Qwen2ModelProvider (dense, Qwen2 architecture).

    Key differences from Qwen3OmniMoeModelProvider:
    - Dense LLM (Qwen2), not MoE
    - Has QKV bias (Qwen2 specific), no QK layernorm
    - mrope_section: [16, 24, 24] (not [24, 20, 20])
    - position_id_per_seconds: 25 (not 13)
    - seconds_per_chunk: 2 for audio-in-video
    - patch_size: 14 (not 16)
    - Uses HF vision model directly (ReplicatedMapping)
    """

    # HF sub-configs for the three Omni components. Only the thinker is always
    # built; talker/token2wav remain None when speech synthesis is not needed.
    thinker_config: Qwen2_5OmniThinkerConfig = field(default_factory=lambda: Qwen2_5OmniThinkerConfig())
    talker_config: Qwen2_5OmniTalkerConfig | None = None
    token2wav_config: Qwen2_5OmniToken2WavConfig | None = None

    pretrained_model_name: str = "Qwen/Qwen2.5-Omni-7B"

    # Token IDs matching Qwen2.5-Omni configuration
    image_token_id: int = 151655
    video_token_id: int = 151656
    audio_token_id: int = 151646
    vision_start_token_id: int = 151652
    vision_end_token_id: int = 151653
    audio_start_token_id: int = 151647
    audio_end_token_id: int = 151648
    bos_token_id: int = 151643
    eos_token_id: int = 151645

    # Attention layout (Qwen2 style): QKV bias present, no QK layernorm.
    head_dim: int = 128
    add_qkv_bias: bool = True
    qk_layernorm: bool = False
    attention_softmax_in_fp32: bool = True
    attention_dropout: float = 0.0

    # Multimodal rotary position embedding (M-RoPE) settings.
    position_embedding_type: str = "mrope"
    apply_rotary_pos_emb_in_fp32: bool = False
    mrope_section: list[int] = field(default_factory=lambda: [16, 24, 24])
    rotary_base: float = 1000000
    # Vision patching parameters (mirror the HF vision config defaults).
    spatial_merge_size: int = 2
    temporal_patch_size: int = 2
    patch_size: int = 14

    scatter_embedding_sequence_parallel: bool = False

    # Temporal alignment for audio-in-video (Qwen2.5-Omni specific values).
    position_id_per_seconds: int = 25
    seconds_per_chunk: int = 2

    # Freeze options
    freeze_language_model: bool = False
    freeze_vision_model: bool = False
    freeze_audio_model: bool = False
    language_max_sequence_length: int = 2048

    # Fusion / execution knobs forwarded to Megatron-Core.
    persist_layer_norm: bool = True
    bias_activation_fusion: bool = True
    bias_dropout_fusion: bool = True
    masked_softmax_fusion: bool = False
    deallocate_pipeline_outputs: bool = True
    async_tensor_model_parallel_allreduce: bool = True
    distribute_saved_activations: bool = False
    cp_comm_type: str = "p2p"

    def provide(self, pre_process=None, post_process=None, vp_stage=None):
        """Provide a Qwen2.5 Omni model instance with vision, audio, and language components.

        Args:
            pre_process: Whether this pipeline stage holds the input embedding.
            post_process: Whether this pipeline stage holds the output layer.
            vp_stage: Virtual pipeline stage index.
                NOTE(review): accepted but not forwarded to Qwen25OmniModel —
                confirm virtual pipeline parallelism is intentionally unsupported.

        Returns:
            Qwen25OmniModel: Assembled multimodal model; components are frozen
            afterwards according to the freeze_* flags.
        """
        # The provider itself doubles as the language TransformerConfig.
        language_transformer_config = self
        thinker_config = self.thinker_config
        talker_config = self.talker_config
        token2wav_config = self.token2wav_config

        # Dense GPT layer spec (no MoE, no QK layernorm for Qwen2)
        language_transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
            num_experts=None,
            moe_grouped_gemm=False,
            qk_layernorm=self.qk_layernorm,
            fp8=False,
        )

        model = Qwen25OmniModel(
            language_transformer_config=language_transformer_config,
            language_transformer_layer_spec=language_transformer_layer_spec,
            thinker_transformer_config=thinker_config,
            talker_transformer_config=talker_config,
            token2wav_transformer_config=token2wav_config,
            pre_process=pre_process,
            post_process=post_process,
            pg_collection=self._pg_collection,
        )

        # Only call freeze() when at least one flag is set, so the default
        # (all-trainable) path never touches requires_grad.
        if self.freeze_language_model or self.freeze_vision_model or self.freeze_audio_model:
            model.freeze(
                freeze_language_model=self.freeze_language_model,
                freeze_vision_model=self.freeze_vision_model,
                freeze_audio_model=self.freeze_audio_model,
            )

        return model

    def provide_language_model(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel:
        """Provide just the language model component without vision/audio.

        Delegates to Qwen2ModelProvider.provide(), which builds a plain
        Megatron-Core GPTModel from this provider's language configuration.
        """
        return super().provide(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage)
diff --git a/tests/unit_tests/models/qwen_omni/modeling_qwen25_omni/test_omni_model.py b/tests/unit_tests/models/qwen_omni/modeling_qwen25_omni/test_omni_model.py
new file mode 100644
index 0000000000..82ae3261c2
--- /dev/null
+++ b/tests/unit_tests/models/qwen_omni/modeling_qwen25_omni/test_omni_model.py
@@ -0,0 +1,415 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Unit tests for Qwen2.5 Omni Model implementation.

Run with: torchrun --nproc_per_node=8 -m pytest tests/unit_tests/models/qwen_omni/modeling_qwen25_omni/test_omni_model.py
Or for single GPU: pytest tests/unit_tests/models/qwen_omni/modeling_qwen25_omni/test_omni_model.py
"""

import datetime
import os

import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.nn.functional as F
from megatron.core import parallel_state
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from transformers import AutoConfig, AutoProcessor

from megatron.bridge.models.qwen_omni.modeling_qwen25_omni.model import Qwen25OmniModel
from megatron.bridge.models.qwen_omni.modeling_qwen25_omni.transformer_config import Qwen25OmniTransformerConfig


@pytest.fixture(scope="module")
def processor():
    """Load HuggingFace processor once for all tests (module-scoped to avoid repeated downloads)."""
    return AutoProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")


@pytest.fixture(scope="module")
def hf_config():
    """Load HuggingFace config once for all tests."""
    return AutoConfig.from_pretrained("Qwen/Qwen2.5-Omni-7B")


@pytest.fixture
def random_image():
    """Generate a random HxWxC uint8 image as a NumPy array (24x24 RGB)."""
    return np.random.randint(0, 255, size=(24, 24, 3), dtype=np.uint8)


@pytest.fixture
def random_video():
    """Generate a random uint8 video array of shape (frames, channels, height, width)."""
    return np.random.randint(0, 255, size=(2, 3, 24, 44), dtype=np.uint8)


@pytest.fixture
def random_audio():
    """Generate a random int16 audio waveform of 800 samples.

    NOTE(review): low bound is -1, so only [-1, 32766] is sampled rather than
    the full int16 range — confirm whether full-range samples were intended.
    """
    return np.random.randint(-1, 32767, size=(800), dtype=np.int16)


class TestQwen25OmniModel:
    """Test suite for Qwen2.5 Omni Model."""

    @classmethod
    def setup_class(cls):
        """Setup distributed process group once for all tests in this class.

        Falls back to a single-rank (world_size=1) group with hard-coded
        rendezvous env vars when no group was initialized by the launcher.
        """
        if not dist.is_initialized():
            os.environ["MASTER_ADDR"] = "127.0.0.1"
            os.environ["MASTER_PORT"] = "29500"
            os.environ["RANK"] = "0"
            os.environ["LOCAL_RANK"] = "0"
            os.environ["WORLD_SIZE"] = "1"

            device_count = torch.cuda.device_count()
            if device_count > 0:
                torch.cuda.set_device(0)

            # NCCL needs a GPU; fall back to gloo on CPU-only machines.
            dist.init_process_group(
                backend="nccl" if device_count > 0 else "gloo",
                world_size=1,
                rank=0,
                timeout=datetime.timedelta(minutes=30),
            )

    @classmethod
    def teardown_class(cls):
        """Teardown distributed process group once after all tests in this class."""
        if dist.is_initialized():
            dist.destroy_process_group()

    def _setup_parallel_state(self, tp_size=1, pp_size=1, cp_size=1):
        """Setup Megatron parallel state with specified parallelism configuration.

        Args:
            tp_size: Tensor model parallel size
            pp_size: Pipeline model parallel size
            cp_size: Context parallel size
        """
        # Clean up any existing parallel state before initializing
        if parallel_state.model_parallel_is_initialized():
            parallel_state.destroy_model_parallel()

        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size=tp_size,
            pipeline_model_parallel_size=pp_size,
            virtual_pipeline_model_parallel_size=None,
            context_parallel_size=cp_size,
        )

        # Fixed seed so TP-sharded weight init is reproducible across tests.
        model_parallel_cuda_manual_seed(123)

    def teardown_method(self):
        """Teardown Megatron parallel state after each test method."""
        parallel_state.destroy_model_parallel()

    @staticmethod
    def get_thinker_transformer_config(hf_config):
        """Create a thinker transformer config for testing.

        Returns:
            Qwen2_5OmniThinkerConfig: HF configuration for the thinker model.
        """
        return hf_config.thinker_config

    @staticmethod
    def get_language_transformer_config(hf_config):
        """Create a language transformer config for testing.

        Uses actual Qwen2.5-Omni-7B model sizes to ensure compatibility
        with the vision/audio model outputs.

        Args:
            hf_config: HuggingFace config object.

        Returns:
            Qwen25OmniTransformerConfig: Configuration for the language model.
        """
        thinker_config = hf_config.thinker_config
        # mrope_section comes from rope_scaling when the HF config carries it;
        # otherwise fall back to the known Qwen2.5-Omni default.
        rope_scaling = getattr(thinker_config.text_config, "rope_scaling", None)
        if rope_scaling is not None:
            mrope_section = rope_scaling.get("mrope_section", [16, 24, 24])
        else:
            mrope_section = [16, 24, 24]

        return Qwen25OmniTransformerConfig(
            # Use actual model dimensions from HF config
            num_layers=4,  # Reduced for testing (actual: thinker_config.text_config.num_hidden_layers)
            hidden_size=thinker_config.text_config.hidden_size,
            num_attention_heads=thinker_config.text_config.num_attention_heads,
            num_query_groups=thinker_config.text_config.num_key_value_heads,
            kv_channels=thinker_config.text_config.hidden_size // thinker_config.text_config.num_attention_heads,
            ffn_hidden_size=thinker_config.text_config.intermediate_size,
            # Qwen2.5-Omni specific
            vocab_size=thinker_config.text_config.vocab_size,
            language_max_sequence_length=thinker_config.text_config.max_position_embeddings,
            # Vision parameters
            patch_size=thinker_config.vision_config.patch_size,
            temporal_patch_size=thinker_config.vision_config.temporal_patch_size,
            in_channels=thinker_config.vision_config.in_channels,
            spatial_merge_size=thinker_config.vision_config.spatial_merge_size,
            # RoPE settings
            rotary_base=getattr(thinker_config.text_config, "rope_theta", 1000000),
            rotary_percent=1.0,
            mrope_section=mrope_section,
            # Training settings
            normalization="RMSNorm",
            activation_func=F.silu,
            gated_linear_unit=True,
            add_bias_linear=False,
            add_qkv_bias=True,
            qk_layernorm=False,
            layernorm_epsilon=thinker_config.text_config.rms_norm_eps,
            bf16=False,
            use_cpu_initialization=True,
            hidden_dropout=0.0,
            attention_dropout=thinker_config.text_config.attention_dropout,
        )

    @staticmethod
    def get_language_model_layer_spec():
        """Create a GPT layer spec for the language model (dense, no MoE).

        Returns:
            ModuleSpec: Layer specification for transformer layers.
        """
        language_model_layer_spec = get_gpt_layer_with_transformer_engine_spec(
            num_experts=None,
            moe_grouped_gemm=False,
            qk_layernorm=False,
            fp8=False,
        )
        return language_model_layer_spec

    @staticmethod
    def get_data_batch(processor, random_image, random_video, random_audio):
        """Generate a batch of data for model forward pass.

        Args:
            processor: HuggingFace processor.
            random_image: Random image array.
            random_video: Random video array.
            random_audio: Random audio waveform.
        Returns:
            dict: A dictionary containing all inputs needed for model forward pass:
                - input_ids: Token IDs [batch, seq_len]
                - attention_mask: Attention mask [batch, seq_len]
                - pixel_values: Image pixel values (layout per HF processor — confirm)
                - image_grid_thw: Image grid dimensions [num_images, 3] (temporal, height, width)
                - pixel_values_videos: Video pixel values (None for images only)
                - video_grid_thw: Video grid dimensions (None for images only)
                - input_features: Audio values (None if no audio)
                - feature_attention_mask: Audio attention mask (None if no audio)
                - video_second_per_grid: Video seconds per grid
        """
        # Create a sample message with image and text
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": random_image,
                    },
                    {
                        "type": "video",
                        "video": random_video,
                    },
                    {
                        "type": "audio",
                        "audio": random_audio,
                    },
                    {"type": "text", "text": "Describe this image, video and audio."},
                ],
            }
        ]

        # Process inputs using HuggingFace processor
        inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )

        # Optional modalities use .get() so missing keys become None.
        batch = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs.get("attention_mask"),
            "pixel_values": inputs.get("pixel_values"),
            "image_grid_thw": inputs.get("image_grid_thw"),
            "pixel_values_videos": inputs.get("pixel_values_videos"),
            "video_grid_thw": inputs.get("video_grid_thw"),
            "input_features":
                inputs.get("input_features"),
            "feature_attention_mask": inputs.get("feature_attention_mask"),
            "video_second_per_grid": inputs.get("video_second_per_grid"),
            "position_ids": None,
            "labels": None,
        }

        # Move tensors to CUDA if available
        if torch.cuda.is_available():
            for key, value in batch.items():
                if value is not None and isinstance(value, torch.Tensor):
                    batch[key] = value.cuda()

        return batch

    @pytest.mark.timeout(50)
    @pytest.mark.parametrize(
        "freeze_all",
        [True, False],
    )
    def test_model_freeze_api(self, freeze_all, hf_config):
        """Test model freeze API.

        Verifies that freeze() flips requires_grad on every parameter when all
        freeze flags are set, and leaves all parameters trainable otherwise.
        """
        self._setup_parallel_state(tp_size=1, pp_size=1)
        pg_collection = ProcessGroupCollection.use_mpu_process_groups()
        assert pg_collection is not None
        assert pg_collection.tp is not None
        assert pg_collection.pp is not None
        assert pg_collection.cp is not None
        assert pg_collection.embd is not None

        language_transformer_config = self.get_language_transformer_config(hf_config)
        language_model_layer_spec = self.get_language_model_layer_spec()
        thinker_transformer_config = self.get_thinker_transformer_config(hf_config)

        model = Qwen25OmniModel(
            language_transformer_config=language_transformer_config,
            language_transformer_layer_spec=language_model_layer_spec,
            thinker_transformer_config=thinker_transformer_config,
            parallel_output=True,
            pre_process=True,
            post_process=True,
            add_encoder=True,
            add_decoder=True,
            pg_collection=pg_collection,
        )

        if torch.cuda.is_available():
            model.to("cuda")

        model.freeze(
            freeze_language_model=freeze_all,
            freeze_vision_model=freeze_all,
            freeze_audio_model=freeze_all,
        )

        # Frozen => requires_grad False; unfrozen => True, for every parameter.
        for name, param in model.named_parameters():
            assert param.requires_grad != freeze_all, f"{name=}"

    @pytest.mark.timeout(50)
    def test_shared_embedding_or_output_weight(self, hf_config):
        """Test shared_embedding_or_output_weight method.

        The shared weight should exist when the decoder is present and be None
        when add_decoder=False.
        """
        self._setup_parallel_state(tp_size=1, pp_size=1)
        pg_collection = ProcessGroupCollection.use_mpu_process_groups()
        assert pg_collection is not None
        assert pg_collection.tp is not None
        assert pg_collection.pp is not None
        assert pg_collection.cp is not None
        assert pg_collection.embd is not None

        language_transformer_config = self.get_language_transformer_config(hf_config)
        language_model_layer_spec = self.get_language_model_layer_spec()
        thinker_transformer_config = self.get_thinker_transformer_config(hf_config)

        # Test with add_decoder=True
        model = Qwen25OmniModel(
            language_transformer_config=language_transformer_config,
            language_transformer_layer_spec=language_model_layer_spec,
            thinker_transformer_config=thinker_transformer_config,
            parallel_output=True,
            pre_process=True,
            post_process=True,
            add_encoder=True,
            add_decoder=True,
            pg_collection=pg_collection,
        )
        weight = model.shared_embedding_or_output_weight()
        assert weight is not None

        # Test with add_decoder=False
        model = Qwen25OmniModel(
            language_transformer_config=language_transformer_config,
            language_transformer_layer_spec=language_model_layer_spec,
            thinker_transformer_config=thinker_transformer_config,
            parallel_output=True,
            pre_process=True,
            post_process=True,
            add_encoder=True,
            add_decoder=False,
            pg_collection=pg_collection,
        )
        weight_no_decoder = model.shared_embedding_or_output_weight()
        assert weight_no_decoder is None

    @pytest.mark.timeout(50)
    def test_set_input_tensor(self, hf_config):
        """Test set_input_tensor method.

        With pre_process=True the tensor should land on the thinker's
        encoder_hidden_state; with pre_process=False it is routed to the
        language model instead.
        """
        self._setup_parallel_state(tp_size=1, pp_size=1)
        pg_collection = ProcessGroupCollection.use_mpu_process_groups()
        assert pg_collection is not None
        assert pg_collection.tp is not None
        assert pg_collection.pp is not None
        assert pg_collection.cp is not None
        assert pg_collection.embd is not None

        language_transformer_config = self.get_language_transformer_config(hf_config)
        language_model_layer_spec = self.get_language_model_layer_spec()
        thinker_transformer_config = self.get_thinker_transformer_config(hf_config)
        hidden_size = language_transformer_config.hidden_size

        # Test with pre_process=True
        model = Qwen25OmniModel(
            language_transformer_config=language_transformer_config,
            language_transformer_layer_spec=language_model_layer_spec,
            thinker_transformer_config=thinker_transformer_config,
            parallel_output=True,
            pre_process=True,
            post_process=True,
            add_encoder=True,
            add_decoder=True,
            pg_collection=pg_collection,
        )

        test_tensor = torch.randn(2, 4, hidden_size)

        # Test with single tensor (not a list)
        model.set_input_tensor([test_tensor])
        assert model.thinker.encoder_hidden_state is not None

        # Test with pre_process=False
        model = Qwen25OmniModel(
            language_transformer_config=language_transformer_config,
            language_transformer_layer_spec=language_model_layer_spec,
            thinker_transformer_config=thinker_transformer_config,
            parallel_output=True,
            pre_process=False,
            post_process=True,
            add_encoder=True,
            add_decoder=True,
            pg_collection=pg_collection,
        )

        # This should set the input tensor on the language model instead
        model.set_input_tensor([test_tensor])
        # No assertion here as it sets internal state