diff --git a/src/megatron/bridge/models/__init__.py b/src/megatron/bridge/models/__init__.py index 8cb0e8ffaa..516aa40286 100644 --- a/src/megatron/bridge/models/__init__.py +++ b/src/megatron/bridge/models/__init__.py @@ -178,6 +178,11 @@ Qwen25ModelProvider72B, Qwen25ModelProvider500M, ) +from megatron.bridge.models.qwen_omni import ( + Qwen3OmniMoeBridge, + Qwen3OmniMoeModel, + Qwen3OmniMoeModelProvider, +) from megatron.bridge.models.qwen_vl import ( Qwen25VLBridge, Qwen25VLModel, @@ -350,6 +355,10 @@ "NemotronVLBridge", "NemotronNano12Bv2Provider", "NemotronNano12Bv2VLModelProvider", + # Omni Models + "Qwen3OmniMoeModel", + "Qwen3OmniMoeBridge", + "Qwen3OmniMoeModelProvider", "SarvamMLABridge", "SarvamMoEBridge", ] diff --git a/src/megatron/bridge/models/qwen_omni/__init__.py b/src/megatron/bridge/models/qwen_omni/__init__.py new file mode 100644 index 0000000000..3524d8aa4b --- /dev/null +++ b/src/megatron/bridge/models/qwen_omni/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.bridge.models.qwen_omni.modeling_qwen3_omni.model import Qwen3OmniMoeModel +from megatron.bridge.models.qwen_omni.qwen3_omni_bridge import Qwen3OmniMoeBridge +from megatron.bridge.models.qwen_omni.qwen3_omni_provider import Qwen3OmniMoeModelProvider + + +__all__ = [ + "Qwen3OmniMoeModel", + "Qwen3OmniMoeBridge", + "Qwen3OmniMoeModelProvider", +] diff --git a/src/megatron/bridge/models/qwen_omni/modeling_qwen3_omni/model.py b/src/megatron/bridge/models/qwen_omni/modeling_qwen3_omni/model.py new file mode 100644 index 0000000000..fb194cb8e6 --- /dev/null +++ b/src/megatron/bridge/models/qwen_omni/modeling_qwen3_omni/model.py @@ -0,0 +1,145 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
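With the re-exports above, the new classes can be imported from the `megatron.bridge.models` package root; a minimal import sketch (illustrative only, not part of the diff):

```python
# Illustrative import surface for the new Qwen3 Omni MoE components.
from megatron.bridge.models import (
    Qwen3OmniMoeBridge,         # HF <-> Megatron conversion bridge
    Qwen3OmniMoeModel,          # Megatron-Core wrapper (currently delegates to the thinker)
    Qwen3OmniMoeModelProvider,  # dataclass provider that constructs the model
)
```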
+ +import torch +from megatron.core import InferenceParams +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec +from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import ( + Qwen3OmniMoeCode2WavConfig, + Qwen3OmniMoeTalkerConfig, + Qwen3OmniMoeThinkerConfig, +) + +from megatron.bridge.models.qwen_omni.modeling_qwen3_omni.thinker_model import Qwen3OmniMoeThinkerModel +from megatron.bridge.models.qwen_omni.modeling_qwen3_omni.transformer_config import Qwen3OmniTransformerConfig + + +class Qwen3OmniMoeModel(MegatronModule): + """Qwen3 Omni Moe Model.""" + + def __init__( + self, + language_transformer_config: Qwen3OmniTransformerConfig, + language_transformer_layer_spec: ModuleSpec, + thinker_transformer_config: Qwen3OmniMoeThinkerConfig, + talker_transformer_config: Qwen3OmniMoeTalkerConfig | None = None, + code2wav_transformer_config: Qwen3OmniMoeCode2WavConfig | None = None, + parallel_output: bool = True, + pre_process: bool = True, + post_process: bool = True, + add_encoder: bool = True, + add_decoder: bool = True, + pg_collection: ProcessGroupCollection = None, + ) -> None: + super().__init__(config=language_transformer_config) + + self.thinker = Qwen3OmniMoeThinkerModel( + language_transformer_config, + language_transformer_layer_spec, + thinker_transformer_config, + parallel_output, + pre_process, + post_process, + add_encoder, + add_decoder, + pg_collection, + ) + + def shared_embedding_or_output_weight(self): + """This is a convenience method to surface the language model's word embeddings, which is + necessary for `finalize_model_grads._allreduce_word_embedding_grads`.""" + return self.thinker.shared_embedding_or_output_weight() + + def set_input_tensor(self, input_tensor) -> None: + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + return self.thinker.set_input_tensor(input_tensor) + + def freeze( + self, + freeze_language_model: bool = False, + freeze_vision_model: bool = False, + freeze_vision_projection: bool = False, + freeze_audio_model: bool = False, + ): + """Freeze model modules. + + Make specific modules non-trainable by setting requires_grad to False. + + Args: + freeze_language_model (bool): Freeze the language model module. + freeze_vision_model (bool): Freeze the vision model module. + freeze_vision_projection (bool): Freeze the vision projection modules (merger and deepstack_merger_list). + freeze_audio_model (bool): Freeze the audio model module. 
+ """ + return self.thinker.freeze( + freeze_language_model, + freeze_vision_model, + freeze_vision_projection, + freeze_audio_model, + ) + + def forward( + self, + input_ids: torch.Tensor, + input_features=None, + position_ids: torch.Tensor = None, # can set at dataset + attention_mask: torch.Tensor = None, + labels: torch.Tensor = None, + loss_mask: torch.Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + pixel_values: torch.Tensor = None, + pixel_values_videos: torch.Tensor = None, + image_grid_thw: torch.Tensor = None, + video_grid_thw: torch.Tensor = None, + # cat set at dataset + image_input_mask: torch.Tensor = None, + video_input_mask: torch.Tensor = None, + feature_attention_mask=None, + audio_feature_lengths=None, + cp_img_num: list[int] = None, + images_padded: list[bool] = None, + use_audio_in_video=None, + video_second_per_grid=None, + **kwargs, + ) -> torch.Tensor: + return self.thinker( + input_ids=input_ids, + input_features=input_features, + position_ids=position_ids, + attention_mask=attention_mask, + labels=labels, + loss_mask=loss_mask, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + extra_block_kwargs=extra_block_kwargs, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_input_mask=image_input_mask, + video_input_mask=video_input_mask, + feature_attention_mask=feature_attention_mask, + audio_feature_lengths=audio_feature_lengths, + cp_img_num=cp_img_num, + images_padded=images_padded, + use_audio_in_video=use_audio_in_video, + video_second_per_grid=video_second_per_grid, + **kwargs, + ) diff --git a/src/megatron/bridge/models/qwen_omni/modeling_qwen3_omni/rope.py b/src/megatron/bridge/models/qwen_omni/modeling_qwen3_omni/rope.py new file mode 100644 index 0000000000..91b7022060 --- /dev/null +++ b/src/megatron/bridge/models/qwen_omni/modeling_qwen3_omni/rope.py @@ -0,0 +1,233 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
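Since `Qwen3OmniMoeModel.freeze(...)` above simply delegates to the thinker, a language-model-only fine-tuning setup can be sketched as follows (hypothetical usage; `model` is assumed to be an already-constructed `Qwen3OmniMoeModel`):

```python
# Hypothetical sketch: train only the MoE language model and freeze the vision
# tower, the vision projection (merger) layers, and the audio tower. The call is
# forwarded to model.thinker.freeze(...), which sets requires_grad=False on the
# parameters of the selected submodules.
model.freeze(
    freeze_language_model=False,
    freeze_vision_model=True,
    freeze_vision_projection=True,
    freeze_audio_model=True,
)
```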
+ +import torch + + +def _get_feat_extract_output_lengths(input_lengths): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + return output_lengths + + +def get_llm_pos_ids_for_vision( + start_idx: int, + vision_idx: int, + spatial_merge_size: int, + t_index: list[torch.Tensor], + grid_hs: list[torch.Tensor], + grid_ws: list[torch.Tensor], +): + """get llm position ids""" + llm_pos_ids_list = [] + llm_grid_h = grid_hs[vision_idx] // spatial_merge_size + llm_grid_w = grid_ws[vision_idx] // spatial_merge_size + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten().float() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten().float() + t_index = torch.Tensor(t_index).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten().float() + _llm_pos_ids = torch.stack([t_index, h_index, w_index]) + llm_pos_ids_list.append(_llm_pos_ids + start_idx) + llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) + return llm_pos_ids + + +def get_rope_index( + spatial_merge_size: int, + image_token_id: int, + video_token_id: int, + audio_token_id: int, + vision_start_token_id: int, + audio_start_token_id: int, + position_id_per_seconds: int, + input_ids: torch.LongTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + use_audio_in_video: bool = False, + audio_seqlens: torch.LongTensor | None = None, + second_per_grids: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. 
+ """ + + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.zeros( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=torch.float, + device=input_ids.device, + ) + image_idx, video_idx, audio_idx = 0, 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + if attention_mask is not None: + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums, audio_nums = 0, 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + audio_nums = torch.sum(input_ids == audio_start_token_id) + image_nums = (vision_tokens == image_token_id).sum() + video_nums = ( + (vision_tokens == audio_start_token_id).sum() + if use_audio_in_video + else (vision_tokens == video_token_id).sum() + ) + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums + multimodal_nums = image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums + for _ in range(multimodal_nums): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + if (image_token_id in input_tokens or video_token_id in input_tokens) and ( + remain_videos > 0 or remain_images > 0 + ): + ed_vision_start = input_tokens.index(vision_start_token_id, st) + else: + ed_vision_start = len(input_tokens) + 1 + if audio_token_id in input_tokens and remain_audios > 0: + ed_audio_start = input_tokens.index(audio_start_token_id, st) + else: + ed_audio_start = len(input_tokens) + 1 + min_ed = min(ed_vision_start, ed_audio_start) + + text_len = min_ed - st + if text_len != 0: + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + st_idx += text_len + # Audio in Video + if min_ed == ed_vision_start and ed_vision_start + 1 == ed_audio_start: + bos_len, eos_len = 2, 2 + else: + bos_len, eos_len = 1, 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + st_idx += bos_len + # Audio Only + if min_ed == ed_audio_start: + audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx]) + llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + llm_pos_ids_list.append(llm_pos_ids) + + st += int(text_len + bos_len + audio_len + eos_len) + audio_idx += 1 + remain_audios -= 1 + + # Image Only + elif min_ed == ed_vision_start and input_ids[ed_vision_start + 1] == image_token_id: + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * position_id_per_seconds).float() + llm_pos_ids = get_llm_pos_ids_for_vision( + st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + + st += int(text_len + bos_len + image_len + eos_len) + image_idx += 1 + remain_images -= 1 + + # Video Only + elif min_ed == ed_vision_start and input_ids[ed_vision_start + 1] == video_token_id: + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ( + torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds + 
).float() + llm_pos_ids = get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + + st += int(text_len + bos_len + video_len + eos_len) + video_idx += 1 + remain_videos -= 1 + + # Audio in Video + elif min_ed == ed_vision_start and ed_vision_start + 1 == ed_audio_start: + audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx]) + audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + + t_index = ( + torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds + ).float() + video_llm_pos_ids = get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + video_data_index, audio_data_index = 0, 0 + while ( + video_data_index < video_llm_pos_ids.shape[-1] + and audio_data_index < audio_llm_pos_ids.shape[-1] + ): + if video_llm_pos_ids[0][video_data_index] <= audio_llm_pos_ids[0][audio_data_index]: + llm_pos_ids_list.append(video_llm_pos_ids[:, video_data_index : video_data_index + 1]) + video_data_index += 1 + else: + llm_pos_ids_list.append(audio_llm_pos_ids[:, audio_data_index : audio_data_index + 1]) + audio_data_index += 1 + if video_data_index < video_llm_pos_ids.shape[-1]: + llm_pos_ids_list.append(video_llm_pos_ids[:, video_data_index : video_llm_pos_ids.shape[-1]]) + if audio_data_index < audio_llm_pos_ids.shape[-1]: + llm_pos_ids_list.append(audio_llm_pos_ids[:, audio_data_index : audio_llm_pos_ids.shape[-1]]) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + + st += int(text_len + bos_len + audio_len + video_len + eos_len) + + audio_idx += 1 + video_idx += 1 + remain_videos -= 1 + remain_audios -= 1 + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat([item.float() for item in llm_pos_ids_list], dim=1).reshape(3, -1) + + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(input_ids)) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + + return position_ids, mrope_position_deltas + else: + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + attention_mask = attention_mask.to(input_ids.device) + position_ids = attention_mask.float().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True) + + return position_ids, mrope_position_deltas diff --git a/src/megatron/bridge/models/qwen_omni/modeling_qwen3_omni/thinker_model.py b/src/megatron/bridge/models/qwen_omni/modeling_qwen3_omni/thinker_model.py new file mode 100644 index 0000000000..23172fd06b --- /dev/null +++ 
b/src/megatron/bridge/models/qwen_omni/modeling_qwen3_omni/thinker_model.py @@ -0,0 +1,428 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core import InferenceParams, tensor_parallel +from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TENorm, + TERowParallelLinear, +) +from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec +from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import ( + Qwen3OmniMoeThinkerConfig as Qwen3OmniMoeThinkerConfigHF, +) +from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( + Qwen3OmniMoeAudioEncoder as Qwen3OmniMoeAudioEncoderHF, +) + +from megatron.bridge.models.qwen_omni.modeling_qwen3_omni.rope import get_rope_index +from megatron.bridge.models.qwen_omni.modeling_qwen3_omni.transformer_config import Qwen3OmniTransformerConfig +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.attention import Qwen3VLSelfAttention +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.text_model import Qwen3VLGPTModel +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_config import get_vision_model_config +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.utils import ( + AllGatherVisionEmbeddings, + PatchMergerSubmodules, + collapse_thw, + get_vision_cp_data, + qwen3vl_cp_split, + reorganize_inputs, + split_data_cp_rank, + split_deepstack_embs, +) +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.vision_model import Qwen3VLVisionModel +from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync + + +class Qwen3OmniMoeThinkerModel(MegatronModule): + """Qwen3 Omni Moe Thinker Model.""" + + def __init__( + self, + language_transformer_config: Qwen3OmniTransformerConfig, + language_transformer_layer_spec: ModuleSpec, + vision_transformer_config: Qwen3OmniMoeThinkerConfigHF, + parallel_output: bool = True, + pre_process: bool = True, + post_process: bool = True, + add_encoder: bool = True, + add_decoder: bool = True, + pg_collection: ProcessGroupCollection = None, + ) -> None: + super().__init__(config=language_transformer_config) + + language_transformer_layer_spec.submodules.self_attention.module = Qwen3VLSelfAttention + + self.pre_process = pre_process + self.post_process = post_process + self.add_encoder = add_encoder + self.add_decoder = add_decoder + + self.encoder_hidden_state = None + self.vision_model = None + self.audio_model = None + self.language_model = None + self.image_token_id = language_transformer_config.image_token_id + self.video_token_id = language_transformer_config.video_token_id + self.audio_token_id = language_transformer_config.audio_token_id + 
self.vision_start_token_id = language_transformer_config.vision_start_token_id + self.audio_start_token_id = language_transformer_config.audio_start_token_id + self.position_id_per_seconds = language_transformer_config.position_id_per_seconds + + self.square_merge_size = vision_transformer_config.vision_config.spatial_merge_size**2 + + # This attribute is needed to check if an all-reduce is required + # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`. + self.share_embeddings_and_output_weights = False + # process groups + self.pg_collection = pg_collection + self.cp_group = pg_collection.cp + self.tp_group = pg_collection.tp + self.pp_group = pg_collection.pp + assert hasattr(self.pg_collection, "embd"), ( + "pg_collection must have an embd group. Previous versions used the default " + "`parallel_state.default_embedding_ranks` to create this process group. " + "If you are using the default process group, please use " + "`parallel_state.get_embedding_group()`. " + "If you don't need an embd group, explicitly set it to None." + ) + self.embd_group = pg_collection.embd + self.vp_stage = None + self.vp_size = self.config.virtual_pipeline_model_parallel_size + + if self.pre_process: + if language_transformer_config.use_hf_vision_model: + raise ValueError("use_hf_vision_model is not supported for Qwen3OmniMoeThinkerModel for now") + # use megatron vision model + vision_transformer_layer_spec = get_vit_layer_with_transformer_engine_spec() + vision_patch_merger_spec = PatchMergerSubmodules( + patch_norm=TENorm, + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ) + + vision_transformer_layer_spec.submodules.self_attention.module = Qwen3VLSelfAttention + megatron_vision_transformer_config = get_vision_model_config( + vision_transformer_config.vision_config, megatron_config=language_transformer_config + ) + megatron_vision_transformer_config.pipeline_model_parallel_size = 1 + megatron_vision_transformer_config.first_pipeline_num_layers = None + + self.vision_model = Qwen3VLVisionModel( + megatron_vision_transformer_config, + vision_transformer_layer_spec, + vision_patch_merger_spec, + pre_process=True, + post_process=True, + ) + + # Initialize audio model with random weights from config + self.audio_model = Qwen3OmniMoeAudioEncoderHF._from_config(vision_transformer_config.audio_config) + # Ensure HF audio tower params are marked for TP grad sync and future assignments are hooked.
+ hook_hf_module_setattr_for_tp_grad_sync(self.audio_model) + + self.language_model = Qwen3VLGPTModel( + config=language_transformer_config, + transformer_layer_spec=language_transformer_layer_spec, + vocab_size=language_transformer_config.vocab_size, + max_sequence_length=language_transformer_config.language_max_sequence_length, + parallel_output=parallel_output, + position_embedding_type="mrope", + rotary_percent=language_transformer_config.rotary_percent, + pre_process=self.pre_process, + post_process=self.post_process, + rotary_base=language_transformer_config.rotary_base, + fp16_lm_cross_entropy=language_transformer_config.fp16_lm_cross_entropy, + share_embeddings_and_output_weights=language_transformer_config.share_embeddings_and_output_weights, + scatter_embedding_sequence_parallel=False, + pg_collection=pg_collection, + ) + assert len(vision_transformer_config.vision_config.deepstack_visual_indexes) < len( + self.language_model.decoder.layers + ), ( + "the deepstack_visual_embeds should be on the first pp-stage", + f"got {len(vision_transformer_config.vision_config.deepstack_visual_indexes)} deepstack_visual_indexes, " + f"{len(self.language_model.decoder.layers)} language model layers", + ) + + self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights + + def shared_embedding_or_output_weight(self): + """This is a convenience method to surface the language model's word embeddings, which is + necessary for `finalize_model_grads._allreduce_word_embedding_grads`.""" + if self.add_decoder: + return self.language_model.shared_embedding_or_output_weight() + return None + + def set_input_tensor(self, input_tensor) -> None: + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + assert len(input_tensor) == 1, "input_tensor should only be length 1 for Qwen3OmniMoeThinkerModel" + + if self.pre_process: + self.encoder_hidden_state = input_tensor[0] + else: + self.language_model.set_input_tensor(input_tensor[0]) + + def freeze( + self, + freeze_language_model: bool = False, + freeze_vision_model: bool = False, + freeze_vision_projection: bool = False, + freeze_audio_model: bool = False, + ): + """Freeze model modules. + + Make specific modules non-trainable by setting requires_grad to False. + + Args: + freeze_language_model (bool): Freeze the language model module. + freeze_vision_model (bool): Freeze the vision model module. + freeze_vision_projection (bool): Freeze the vision projection modules. + freeze_audio_model (bool): Freeze the audio model module.
+ """ + modules = [] + + if freeze_language_model and self.language_model is not None: + modules.append(self.language_model) + + if freeze_vision_model and self.vision_model is not None: + modules.append(self.vision_model) + + if freeze_vision_projection and self.vision_model is not None: + modules.append(self.vision_model.decoder.deepstack_merger_list) + modules.append(self.vision_model.merger) + + if freeze_audio_model and self.audio_model is not None: + modules.append(self.audio_model) + + for module in modules: + for param in module.parameters(): + param.requires_grad = False + + def get_audio_features( + self, + input_features: torch.FloatTensor, + feature_attention_mask: torch.LongTensor | None = None, + audio_feature_lengths: torch.LongTensor | None = None, + ): + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) + else: + audio_feature_lengths = None + + # feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + feature_lens = audio_feature_lengths + audio_outputs = self.audio_model( + input_features, + feature_lens=feature_lens, + ) + + return audio_outputs.last_hidden_state + + def forward( + self, + input_ids: torch.Tensor, + input_features=None, + position_ids: torch.Tensor = None, # can set at dataset + attention_mask: torch.Tensor = None, + labels: torch.Tensor = None, + loss_mask: torch.Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + pixel_values: torch.Tensor = None, + pixel_values_videos: torch.Tensor = None, + image_grid_thw: torch.Tensor = None, + video_grid_thw: torch.Tensor = None, + # cat set at dataset + image_input_mask: torch.Tensor = None, + video_input_mask: torch.Tensor = None, + feature_attention_mask=None, + audio_feature_lengths=None, + cp_img_num: list[int] = None, + images_padded: list[bool] = None, + use_audio_in_video=None, + video_second_per_grid=None, + **kwargs, + ) -> torch.Tensor: + assert inference_params is None, "not support inference" + assert packed_seq_params is None, "not support packed_seq_params" + + vision_grid_thw = None + vision_data = None + vision_mask = None + deepstack_feature_lists = None + + cp_rank = self.pg_collection.cp.rank() + cp_size = self.pg_collection.cp.size() + + if self.pre_process: + vision_data, vision_grid_thw, vision_mask = reorganize_inputs( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_input_mask=image_input_mask, + video_input_mask=video_input_mask, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + square_merge_size=self.square_merge_size, + ) + + vision_embeds = None + if vision_grid_thw is not None and vision_grid_thw.shape[0] > 0: + if cp_size > 1: + if cp_img_num is None: + assert images_padded is None + vision_data, vision_grid_thw, cp_img_num, images_padded = qwen3vl_cp_split( + cp_size, + vision_data, + vision_grid_thw, + ) + vision_data, vision_grid_thw, seqlen_on_cp_ranks = get_vision_cp_data( + vision_data, + vision_grid_thw, + self.square_merge_size, + cp_img_num, + images_padded, + cp_rank, + cp_size, + ) + vision_grid_thw = collapse_thw(vision_grid_thw) + + if vision_data.shape[0] > 0: + vision_embeds, deepstack_feature_lists = self.vision_model( + 
hidden_states=vision_data, + grid_thw=vision_grid_thw, + ) + else: + vision_embeds = torch.zeros( + (0, self.language_model.config.hidden_size), + device=vision_data.device, + dtype=torch.bfloat16, + ) + deepstack_feature_lists = [] + for _ in self.vision_model.config.deepstack_visual_indexes: + deepstack_feature_lists.append( + torch.zeros( + (0, self.language_model.config.hidden_size), + device=vision_data.device, + dtype=torch.bfloat16, + ) + ) + + if cp_size > 1: + vision_embeds = AllGatherVisionEmbeddings.apply( + vision_embeds, + seqlen_on_cp_ranks, + self.pg_collection.cp, + ) + for i in range(len(deepstack_feature_lists)): + deepstack_feature_lists[i] = AllGatherVisionEmbeddings.apply( + deepstack_feature_lists[i], + seqlen_on_cp_ranks, + self.pg_collection.cp, + ) + + audio_embeds = None + if input_features is not None: + audio_embeds = self.get_audio_features( + input_features, + feature_attention_mask=feature_attention_mask, + audio_feature_lengths=audio_feature_lengths, + ) + audio_mask = input_ids == self.audio_token_id + + combined_embeddings = self.language_model.embedding( + input_ids=input_ids, + position_ids=None, # NOTE: disable + ).clone() # [text_seq_len, b, h_language] + + if vision_embeds is not None or audio_embeds is not None: + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + if vision_embeds is not None: + combined_embeddings[vision_mask] = vision_embeds + if audio_embeds is not None: + combined_embeddings[audio_mask] = audio_embeds + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + + if combined_embeddings is not None and cp_size > 1 and packed_seq_params is None: + combined_embeddings = split_data_cp_rank(combined_embeddings, cp_size, 0, cp_rank) + if self.config.sequence_parallel: + combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings) + combined_embeddings = combined_embeddings.contiguous() + else: + combined_embeddings = None + + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + else: + audio_feature_lengths = None + + if position_ids is None: + position_ids, _ = get_rope_index( + self.config.spatial_merge_size, + self.image_token_id, + self.video_token_id, + self.audio_token_id, + self.vision_start_token_id, + self.audio_start_token_id, + self.position_id_per_seconds, + input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + attention_mask=attention_mask, + use_audio_in_video=use_audio_in_video, + audio_seqlens=audio_feature_lengths, + second_per_grids=video_second_per_grid, + ) + + visual_pos_masks = vision_mask + deepstack_visual_embeds = deepstack_feature_lists + if self.config.sequence_parallel or cp_size > 1: + visual_pos_masks, deepstack_visual_embeds = split_deepstack_embs( + visual_pos_masks, + deepstack_visual_embeds, + tp_size=self.pg_collection.tp.size(), + tp_rank=self.pg_collection.tp.rank(), + cp_size=cp_size, + cp_rank=self.pg_collection.cp.rank(), + sequence_parallel=self.config.sequence_parallel, + ) + + output = self.language_model( + input_ids=None, + position_ids=position_ids, # None in encoder + attention_mask=attention_mask, # None in encoder + decoder_input=combined_embeddings, # only not None in the first decoder PP stage + labels=labels, # only not None in the last decoder PP stage + loss_mask=loss_mask, + inference_params=inference_params, # currently always None + packed_seq_params=packed_seq_params, # currently always None + visual_pos_masks=visual_pos_masks, + 
deepstack_visual_embeds=deepstack_visual_embeds, + **(extra_block_kwargs or {}), + ) + + return output diff --git a/src/megatron/bridge/models/qwen_omni/modeling_qwen3_omni/transformer_config.py b/src/megatron/bridge/models/qwen_omni/modeling_qwen3_omni/transformer_config.py new file mode 100644 index 0000000000..e316963b42 --- /dev/null +++ b/src/megatron/bridge/models/qwen_omni/modeling_qwen3_omni/transformer_config.py @@ -0,0 +1,55 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import List + +from megatron.core.transformer.transformer_config import TransformerConfig +from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLTextConfig + + +@dataclass +class Qwen3OmniTransformerConfig(TransformerConfig): + """Configuration for Qwen3 Omni transformer with vision and language components.""" + + vocab_size: int = 152064 + language_max_sequence_length: int = 4096 + + patch_size: int = 16 + temporal_patch_size: int = 2 + in_channels: int = 3 + spatial_merge_size: int = 2 + num_position_embeddings: int = 2304 + out_hidden_size: int = 2048 + + apply_rotary_pos_emb_in_fp32: bool = False + deepstack_visual_indexes: List[int] = field(default_factory=lambda: [8, 16, 24]) + fp16_lm_cross_entropy: bool = False + share_embeddings_and_output_weights: bool = False + rotary_percent: float = 1.0 + rotary_base: float = 10000 + + # Multimodal rope section for [temporal, height, width] dimensions + mrope_section: List[int] = field(default_factory=lambda: [24, 20, 20]) + apply_rope_fusion: bool = False + + image_token_id: int = 151655 + video_token_id: int = 151656 + audio_token_id: int = 151675 + vision_start_token_id: int = 151652 + audio_start_token_id: int = 151669 + position_id_per_seconds: int = 13 + hf_text_config: Qwen3VLTextConfig | None = None + vision_dp_when_cp: bool = False + use_hf_vision_model: bool = False diff --git a/src/megatron/bridge/models/qwen_omni/qwen3_omni_bridge.py b/src/megatron/bridge/models/qwen_omni/qwen3_omni_bridge.py new file mode 100644 index 0000000000..8420781541 --- /dev/null +++ b/src/megatron/bridge/models/qwen_omni/qwen3_omni_bridge.py @@ -0,0 +1,246 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
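For orientation, `Qwen3OmniTransformerConfig` above mainly layers multimodal token IDs and MRoPE/vision fields on top of Megatron's `TransformerConfig`; a minimal construction sketch (the base transformer sizes are hypothetical placeholders — in practice `Qwen3OmniMoeModelProvider` derives everything from the HF checkpoint):

```python
# Illustrative only: num_layers/hidden_size/num_attention_heads are placeholder
# values; the Omni-specific fields shown match the dataclass defaults above.
config = Qwen3OmniTransformerConfig(
    num_layers=48,            # hypothetical
    hidden_size=2048,         # hypothetical
    num_attention_heads=32,   # hypothetical
    image_token_id=151655,
    video_token_id=151656,
    audio_token_id=151675,
    vision_start_token_id=151652,
    audio_start_token_id=151669,
    position_id_per_seconds=13,
    mrope_section=[24, 20, 20],
)
```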
+ +import torch +from transformers import Qwen3OmniMoeForConditionalGeneration + +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.conversion.param_mapping import ( + AutoMapping, + ConcatenatedQKVMapping, + GatedMLPMapping, + QKVMapping, + ReplicatedMapping, +) +from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM +from megatron.bridge.models.qwen_omni.modeling_qwen3_omni.model import Qwen3OmniMoeModel +from megatron.bridge.models.qwen_omni.qwen3_omni_provider import Qwen3OmniMoeModelProvider + + +@MegatronModelBridge.register_bridge(source=Qwen3OmniMoeForConditionalGeneration, target=Qwen3OmniMoeModel) +class Qwen3OmniMoeBridge(MegatronModelBridge): + """ + Megatron Bridge for Qwen3-Omni-Moe Conditional Generation. + + This bridge handles the conversion between HuggingFace Qwen3OmniMoeForConditionalGeneration + and Megatron-Core Qwen3OmniMoeModel formats, including weight mappings and + configuration translation for Omni Moe models. + + The weight mappings handle: + 1. Standard language model mappings (embeddings, layer norms, output) + 2. Vision model mappings + 3. QKV mappings with QK layernorm + 4. MoE-specific mappings: + - Router weights for expert selection + - Expert MLPs (multiple experts per layer) + - Pre-MLP layernorm + 5. Deepstack visual merger mappings + 6. Audio model mappings + + Example: + >>> from megatron.bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen3-Omni-30B-A3B-Instruct") + >>> provider = bridge.to_megatron_provider() + """ + + def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Qwen3OmniMoeModelProvider: + """ + Create a Qwen3OmniMoeModelProvider from a HuggingFace pretrained model. 
+ + Args: + hf_pretrained: HuggingFace pretrained VLM model + + Returns: + Qwen3OmniMoeModelProvider configured with the HF model's parameters + """ + hf_config = hf_pretrained.config + thinker_config = hf_config.thinker_config + talker_config = hf_config.talker_config + code2wav_config = hf_config.code2wav_config + + text_config = thinker_config.text_config + model_dtype = self.dtype_from_hf(thinker_config, default=torch.float32) + + provider = Qwen3OmniMoeModelProvider( + thinker_config=thinker_config, + talker_config=talker_config, + code2wav_config=code2wav_config, + num_layers=text_config.num_hidden_layers, + hidden_size=text_config.hidden_size, + ffn_hidden_size=text_config.intermediate_size, # Dense FFN size (for non-MoE layers if any) + moe_ffn_hidden_size=text_config.moe_intermediate_size, # Expert FFN size + num_attention_heads=text_config.num_attention_heads, + num_query_groups=text_config.num_key_value_heads, # GQA configuration + head_dim=getattr(text_config, "head_dim", text_config.hidden_size // text_config.num_attention_heads), + init_method_std=text_config.initializer_range, + layernorm_epsilon=text_config.rms_norm_eps, + gated_linear_unit=True, # Qwen3 MoE uses gated linear units + make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(text_config.vocab_size), + rotary_base=getattr(text_config, "rope_theta", 1000000), # Default Qwen3 omni rope theta + share_embeddings_and_output_weights=getattr(text_config, "tie_word_embeddings", False), + vocab_size=text_config.vocab_size, + seq_length=text_config.max_position_embeddings, + fp16=(model_dtype == torch.float16), + bf16=(model_dtype == torch.bfloat16), + params_dtype=model_dtype, + # Qwen3 specific parameters + add_qkv_bias=text_config.attention_bias, # Qwen3 can have bias in QKV + qk_layernorm=True, # Qwen3 uses QK layernorm + # MoE specific parameters + num_moe_experts=text_config.num_experts, + moe_router_topk=text_config.num_experts_per_tok, + decoder_sparse_step=getattr(text_config, "decoder_sparse_step", 1), # Default to every layer being MoE + mlp_only_layers=getattr(text_config, "mlp_only_layers", []), # Default to all layers using MoE + # Store the original HF text config for RoPE initialization + hf_text_config=text_config, + # Vision-Language token IDs + bos_token_id=getattr(text_config, "bos_token_id", 151643), + eos_token_id=getattr(text_config, "eos_token_id", 151645), + vision_start_token_id=getattr(thinker_config, "vision_start_token_id", 151652), + vision_end_token_id=getattr(thinker_config, "vision_end_token_id", 151653), + audio_start_token_id=getattr(thinker_config, "audio_start_token_id", 151669), + audio_end_token_id=getattr(thinker_config, "audio_end_token_id", 151670), + image_token_id=getattr(thinker_config, "image_token_id", 151655), + video_token_id=getattr(thinker_config, "video_token_id", 151656), + audio_token_id=getattr(thinker_config, "audio_token_id", 151675), + tts_bos_token_id=getattr(hf_config, "tts_bos_token_id", 151672), + tts_eos_token_id=getattr(hf_config, "tts_eos_token_id", 151673), + tts_pad_token_id=getattr(hf_config, "tts_pad_token_id", 151671), + # MRoPE configuration for multimodal position embeddings + mrope_section=getattr(text_config, "rope_scaling", {}).get("mrope_section", [24, 20, 20]), + position_id_per_seconds=getattr(thinker_config, "position_id_per_seconds", 13), + ) + + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + """ + Return MegatronMappingRegistry containing parameter mappings for MoE models. 
+ + The MoE mappings include: + 1. Standard language model mappings (embeddings, layer norms, output) + 2. Vision model mappings + 3. QKV mappings with QK layernorm + 4. MoE-specific mappings: + - Router weights for expert selection + - Expert MLPs (multiple experts per layer) + - Pre-MLP layernorm + 5. Deepstack visual merger mappings + 6. Audio model mappings + + Returns: + MegatronMappingRegistry with all MoE parameter mappings + """ + # Language model direct mappings (same as dense model) + param_mappings = { + # Embeddings and output layers + "thinker.language_model.embedding.word_embeddings.weight": "thinker.model.embed_tokens.weight", + "thinker.language_model.output_layer.weight": "thinker.lm_head.weight", + "thinker.language_model.decoder.final_layernorm.weight": "thinker.model.norm.weight", + # Layer normalization for attention + "thinker.language_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "thinker.model.layers.*.input_layernorm.weight", + # MoE-specific: pre-MLP layernorm + "thinker.language_model.decoder.layers.*.pre_mlp_layernorm.weight": "thinker.model.layers.*.post_attention_layernorm.weight", + # Attention output projection + "thinker.language_model.decoder.layers.*.self_attention.linear_proj.weight": "thinker.model.layers.*.self_attn.o_proj.weight", + # QK layernorm weights (Qwen3 specific) + "thinker.language_model.decoder.layers.*.self_attention.q_layernorm.weight": "thinker.model.layers.*.self_attn.q_norm.weight", + "thinker.language_model.decoder.layers.*.self_attention.k_layernorm.weight": "thinker.model.layers.*.self_attn.k_norm.weight", + # MoE router weights + "thinker.language_model.decoder.layers.*.mlp.router.weight": "thinker.model.layers.*.mlp.gate.weight", + # MLP output projection + "thinker.language_model.decoder.layers.*.mlp.experts.linear_fc2.weight*": "thinker.model.layers.*.mlp.experts.*.down_proj.weight", + # vision module attn + "thinker.vision_model.decoder.layers.*.self_attention.linear_proj.weight": "thinker.visual.blocks.*.attn.proj.weight", + "thinker.vision_model.decoder.layers.*.self_attention.linear_proj.bias": "thinker.visual.blocks.*.attn.proj.bias", + # vision module mlp + "thinker.vision_model.decoder.layers.*.mlp.linear_fc1.weight": "thinker.visual.blocks.*.mlp.linear_fc1.weight", + "thinker.vision_model.decoder.layers.*.mlp.linear_fc1.bias": "thinker.visual.blocks.*.mlp.linear_fc1.bias", + "thinker.vision_model.decoder.layers.*.mlp.linear_fc2.weight": "thinker.visual.blocks.*.mlp.linear_fc2.weight", + "thinker.vision_model.decoder.layers.*.mlp.linear_fc2.bias": "thinker.visual.blocks.*.mlp.linear_fc2.bias", + # vision module norm + "thinker.vision_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "thinker.visual.blocks.*.norm1.weight", + "thinker.vision_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_bias": "thinker.visual.blocks.*.norm1.bias", + "thinker.vision_model.decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "thinker.visual.blocks.*.norm2.weight", + "thinker.vision_model.decoder.layers.*.mlp.linear_fc1.layer_norm_bias": "thinker.visual.blocks.*.norm2.bias", + # # vision module deepstack merger + "thinker.vision_model.decoder.deepstack_merger_list.*.patch_norm.weight": "thinker.visual.merger_list.*.ln_q.weight", + "thinker.vision_model.decoder.deepstack_merger_list.*.patch_norm.bias": "thinker.visual.merger_list.*.ln_q.bias", + "thinker.vision_model.decoder.deepstack_merger_list.*.linear_fc1.weight": "thinker.visual.merger_list.*.mlp.0.weight", + 
"thinker.vision_model.decoder.deepstack_merger_list.*.linear_fc1.bias": "thinker.visual.merger_list.*.mlp.0.bias", + "thinker.vision_model.decoder.deepstack_merger_list.*.linear_fc2.weight": "thinker.visual.merger_list.*.mlp.2.weight", + "thinker.vision_model.decoder.deepstack_merger_list.*.linear_fc2.bias": "thinker.visual.merger_list.*.mlp.2.bias", + # vision module merger + "thinker.vision_model.merger.patch_norm.**": "thinker.visual.merger.ln_q.**", + "thinker.vision_model.merger.linear_fc1.weight": "thinker.visual.merger.mlp.0.weight", + "thinker.vision_model.merger.linear_fc1.bias": "thinker.visual.merger.mlp.0.bias", + "thinker.vision_model.merger.linear_fc2.weight": "thinker.visual.merger.mlp.2.weight", + "thinker.vision_model.merger.linear_fc2.bias": "thinker.visual.merger.mlp.2.bias", + } + + mapping_list = [] + + # Convert simple 1:1 mappings to AutoMapping objects + for megatron_param, hf_param in param_mappings.items(): + mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) + + # Add special mappings that require parameter transformation + mapping_list.extend( + [ + # Audio model weights are replicated directly + ReplicatedMapping( + megatron_param="thinker.audio_model.**", + hf_param="thinker.audio_tower.**", + ), + # QKV mapping: Combine separate Q, K, V matrices + QKVMapping( + megatron_param="thinker.language_model.decoder.layers.*.self_attention.linear_qkv.weight", + q="thinker.model.layers.*.self_attn.q_proj.weight", + k="thinker.model.layers.*.self_attn.k_proj.weight", + v="thinker.model.layers.*.self_attn.v_proj.weight", + ), + # QKV bias mapping (if attention_bias is True) + QKVMapping( + megatron_param="thinker.language_model.decoder.layers.*.self_attention.linear_qkv.bias", + q="thinker.model.layers.*.self_attn.q_proj.bias", + k="thinker.model.layers.*.self_attn.k_proj.bias", + v="thinker.model.layers.*.self_attn.v_proj.bias", + ), + GatedMLPMapping( + megatron_param="thinker.language_model.decoder.layers.*.mlp.experts.linear_fc1.weight*", + gate="thinker.model.layers.*.mlp.experts.*.gate_proj.weight", + up="thinker.model.layers.*.mlp.experts.*.up_proj.weight", + ), + # QKV mapping for vision model + ConcatenatedQKVMapping( + megatron_param="thinker.vision_model.decoder.layers.*.self_attention.linear_qkv.weight", + hf_param="thinker.visual.blocks.*.attn.qkv.weight", + ), + ConcatenatedQKVMapping( + megatron_param="thinker.vision_model.decoder.layers.*.self_attention.linear_qkv.bias", + hf_param="thinker.visual.blocks.*.attn.qkv.bias", + ), + ReplicatedMapping( # These patch_embed are conv, we need to use ReplicatedMapping + megatron_param="thinker.vision_model.patch_embed.proj.**", + hf_param="thinker.visual.patch_embed.proj.**", + ), + ReplicatedMapping( + megatron_param="thinker.vision_model.pos_embed.weight", + hf_param="thinker.visual.pos_embed.weight", + ), + ] + ) + + return MegatronMappingRegistry(*mapping_list) diff --git a/src/megatron/bridge/models/qwen_omni/qwen3_omni_provider.py b/src/megatron/bridge/models/qwen_omni/qwen3_omni_provider.py new file mode 100644 index 0000000000..df1b21f332 --- /dev/null +++ b/src/megatron/bridge/models/qwen_omni/qwen3_omni_provider.py @@ -0,0 +1,220 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Qwen3 Omni MoE Model Provider configurations for Megatron-Core. + +This module provides configuration classes for Qwen3 Omni MoE (Mixture of Experts) multimodal models, +compatible with HuggingFace's Qwen3-Omni-MoE model configurations. +Reference: https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Thinking +""" + +from dataclasses import dataclass, field +from typing import List + +from megatron.core.models.gpt import GPTModel as MCoreGPTModel +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import ( + Qwen3OmniMoeCode2WavConfig, + Qwen3OmniMoeTalkerConfig, + Qwen3OmniMoeTextConfig, + Qwen3OmniMoeThinkerConfig, +) + +from megatron.bridge.models import Qwen3MoEModelProvider +from megatron.bridge.models.qwen_omni.modeling_qwen3_omni.model import Qwen3OmniMoeModel + + +@dataclass +class Qwen3OmniMoeModelProvider(Qwen3MoEModelProvider): + """ + Base model provider for Qwen3 Omni Moe Models. + Inherits language model configuration from Qwen3MoeModelProvider. + + Key MoE Parameters (inherited from Qwen3MoEModelProvider): + - num_moe_experts: Number of total experts (default 128) + - moe_router_topk: Number of experts selected per token (default 8) + - moe_router_load_balancing_type: Load balancing strategy (default "aux_loss") + - moe_aux_loss_coeff: Auxiliary loss coefficient (default 1e-3) + - moe_grouped_gemm: Use grouped GEMM for efficiency (default True) + + Note: num_query_groups in parent class corresponds to num_key_value_heads in HF config. 
+ """ + + thinker_config: Qwen3OmniMoeThinkerConfig = field(default_factory=lambda: Qwen3OmniMoeThinkerConfig()) + talker_config: Qwen3OmniMoeTalkerConfig | None = None + code2wav_config: Qwen3OmniMoeCode2WavConfig | None = None + hf_text_config: Qwen3OmniMoeTextConfig | None = None + + pretrained_model_name: str = "Qwen/Qwen3-Omni-30B-A3B-Instruct" + + # Vision-specific token IDs matching Qwen3-Omni-MoE configuration + # Based on HuggingFace Qwen3-Omni-MoE configs + # Token ID for image placeholder in text + image_token_id: int = 151655 + # Token ID for video placeholder in text + video_token_id: int = 151656 + # Token ID for audio placeholder in text + audio_token_id: int = 151675 + # Token ID marking start of vision content + vision_start_token_id: int = 151652 + # Token ID marking end of vision content + vision_end_token_id: int = 151653 + # Token ID marking start of audio content + audio_start_token_id: int = 151669 + # Token ID marking end of audio content + audio_end_token_id: int = 151670 + # BOS token ID for Qwen3-Omni models + bos_token_id: int = 151643 + # EOS token ID for Qwen3-Omni models + eos_token_id: int = 151645 + tts_bos_token_id: int = 151672 + tts_eos_token_id: int = 151673 + tts_pad_token_id: int = 151671 + + head_dim: int = 128 + qk_layernorm: bool = True + attention_softmax_in_fp32: bool = True + attention_dropout: float = 0.0 + + # Override position embedding for multimodal rope + position_embedding_type: str = "mrope" + + apply_rotary_pos_emb_in_fp32: bool = False + + # Multimodal rope section for [temporal, height, width] dimensions + # Based on HuggingFace Qwen3-Omni config: mrope_section: [24, 20, 20] + mrope_section: List[int] = field(default_factory=lambda: [24, 20, 20]) + + # RoPE theta value specific to Qwen3-Omni models + # From HuggingFace config: rope_theta: 1000000 + rotary_base: float = 1000000 + spatial_merge_size: int = 2 + temporal_patch_size: int = 2 + patch_size: int = 16 + + # Override to disable scattering embeddings for vision insertion + scatter_embedding_sequence_parallel: bool = False + + # Router configuration + moe_router_pre_softmax: bool = False # Qwen3 specific + moe_router_dtype: str = "fp32" # Use FP32 for router computations + moe_router_score_function: str = "softmax" # Softmax scoring + moe_router_bias_update_rate: float = 0.001 # Router bias update rate + + # MoE optimization settings + moe_permute_fusion: bool = True # Fuse permutation operations + moe_token_dispatcher_type: str = "alltoall" # All-to-all communication + + # Dense layers configuration (some layers may not use MoE) + # Empty list means all layers use MoE, otherwise specify layer indices + mlp_only_layers: List[int] = field(default_factory=list) + + # Decoder sparse step (frequency of MoE layers) + decoder_sparse_step: int = 1 # Every layer is MoE by default + + # Freeze options for fine-tuning scenarios + # Whether to freeze language model weights + freeze_language_model: bool = False + # Whether to freeze vision encoder weights + freeze_vision_model: bool = False + # Whether to freeze vision-to-language projection weights + freeze_vision_projection: bool = False + # Whether to freeze audio model weights + freeze_audio_model: bool = False + language_max_sequence_length: int = 2048 + + # QK layernorm is already True in Qwen3MoEModelProvider, no need to redefine + + # These are typically set in the base class but documented here for clarity + persist_layer_norm: bool = True # Persist layer norm for efficiency + bias_activation_fusion: bool = True # Fuse bias and
activation + bias_dropout_fusion: bool = True # Fuse bias and dropout + masked_softmax_fusion: bool = False # Don't fuse masked softmax (Qwen specific) + deallocate_pipeline_outputs: bool = True # Deallocate pipeline outputs to save memory + async_tensor_model_parallel_allreduce: bool = True # Async tensor parallel + distribute_saved_activations: bool = False # Don't distribute saved activations + cp_comm_type: str = "p2p" # Point-to-point communication for context parallel + position_id_per_seconds: int = 13 + + use_hf_vision_model: bool = False + vision_dp_when_cp: bool = False + + def finalize(self) -> None: + if self.tensor_model_parallel_size > 1: + self.sequence_parallel = True + + super().finalize() + + def provide(self, pre_process=None, post_process=None, vp_stage=None): + """ + Provide a Qwen3 Omni MoE model instance with vision and language components. + """ + language_transformer_config = self + + # Create vision transformer config - placeholder for future use + # vision_transformer_config = deepcopy(self) + thinker_config = self.thinker_config + talker_config = self.talker_config + code2wav_config = self.code2wav_config + + language_transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=self.num_moe_experts, + moe_grouped_gemm=True, + qk_layernorm=self.qk_layernorm, + fp8=False, + ) + + # reuse Qwen3OmniMoeAudioEncoder and Qwen3OmniMoeVisionEncoder for MoE model but replace the language model with MoE language model + model = Qwen3OmniMoeModel( + language_transformer_config=language_transformer_config, + language_transformer_layer_spec=language_transformer_layer_spec, + thinker_transformer_config=thinker_config, + talker_transformer_config=talker_config, + code2wav_transformer_config=code2wav_config, + pre_process=pre_process, + post_process=post_process, + pg_collection=self._pg_collection, + ) + + # Apply freeze options if any are enabled for fine-tuning + if ( + self.freeze_language_model + or self.freeze_vision_model + or self.freeze_vision_projection + or self.freeze_audio_model + ): + model.freeze( + freeze_language_model=self.freeze_language_model, + freeze_vision_model=self.freeze_vision_model, + freeze_vision_projection=self.freeze_vision_projection, + freeze_audio_model=self.freeze_audio_model, + ) + + return model + + def provide_language_model(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel: + """ + Provide just the language MoE model component without vision. + + Args: + pre_process: Whether this is the first stage in pipeline parallelism + post_process: Whether this is the last stage in pipeline parallelism + vp_stage: Virtual pipeline stage number + + Returns: + MCoreGPTModel instance (MoE language model only) + """ + # Use parent class to create standard MoE language model + return super().provide(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) diff --git a/src/megatron/bridge/recipes/qwen_omni/__init__.py b/src/megatron/bridge/recipes/qwen_omni/__init__.py new file mode 100644 index 0000000000..12a7cbc9e4 --- /dev/null +++ b/src/megatron/bridge/recipes/qwen_omni/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Qwen3 models +from .qwen3_omni_moe import ( + qwen3_omni_moe_30b_a3b_finetune_config, + qwen3_omni_moe_30b_a3b_pretrain_config, +) + + +__all__ = [ + # Qwen3-Omni-Moe pretrain configs + "qwen3_omni_moe_30b_a3b_pretrain_config", + # Qwen3-Omni-Moe finetune configs + "qwen3_omni_moe_30b_a3b_finetune_config", +] diff --git a/src/megatron/bridge/recipes/qwen_omni/qwen3_omni_moe.py b/src/megatron/bridge/recipes/qwen_omni/qwen3_omni_moe.py new file mode 100644 index 0000000000..bb998cb9f7 --- /dev/null +++ b/src/megatron/bridge/recipes/qwen_omni/qwen3_omni_moe.py @@ -0,0 +1,437 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import List, Optional, Union + +import torch +from transformers import AutoTokenizer, Qwen2VLImageProcessor +from typing_extensions import TypedDict, Unpack + +from megatron.bridge import AutoBridge +from megatron.bridge.data.vlm_datasets import ( + EnergonProvider, + HFDatasetConversationProvider, + MockVLMConversationProvider, + PreloadedVLMConversationProvider, +) +from megatron.bridge.peft.base import PEFT +from megatron.bridge.recipes.qwen_vl.data.energon.task_encoder import QwenVLTaskEncoder +from megatron.bridge.recipes.utils.finetune_utils import default_peft_config as _default_peft_config +from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing +from megatron.bridge.recipes.utils.tokenizer_utils import DEFAULT_NULL_TOKENIZER_VOCAB_SIZE +from megatron.bridge.training.comm_overlap import CommOverlapConfig +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + DatasetProvider, + DistributedDataParallelConfig, + LoggerConfig, + RNGConfig, + TokenizerConfig, + TrainingConfig, + ValidationConfig, +) +from megatron.bridge.training.flex_dispatcher_backend import apply_flex_dispatcher_backend +from megatron.bridge.training.mixed_precision import MixedPrecisionConfig, bf16_mixed + + +class Qwen3OmniCommonKwargs(TypedDict, total=False): + """Typed options accepted by Qwen3 Omni MoE recipe helpers.""" + + # Core identifiers + hf_path: str + dir: Optional[str] + name: str + # Dataset configuration + data_paths: Optional[List[str]] + data_args_path: Optional[str] + train_data_path: Optional[List[str]] + valid_data_path: Optional[List[str]] + test_data_path: Optional[List[str]] + per_split_data_args_path: Optional[str] + mock: bool + # Model configuration + tensor_model_parallel_size: int + pipeline_model_parallel_size: int + pipeline_dtype: Optional[torch.dtype] + virtual_pipeline_model_parallel_size: 
Optional[int]
+    context_parallel_size: int
+    expert_model_parallel_size: Optional[int]
+    expert_tensor_parallel_size: int
+    sequence_parallel: bool
+    use_megatron_fsdp: bool
+    enable_recompute: bool
+    account_for_embedding_in_pipeline_split: bool
+    account_for_loss_in_pipeline_split: bool
+    # Training hyperparameters
+    train_iters: int
+    global_batch_size: int
+    micro_batch_size: int
+    seq_length: int
+    lr: float
+    min_lr: float
+    lr_warmup_iters: int
+    lr_decay_iters: Optional[int]
+    eval_interval: int
+    save_interval: int
+    use_null_tokenizer: bool
+    # Precision / overlap configs
+    precision_config: Optional[Union[MixedPrecisionConfig, str]]
+    comm_overlap_config: Optional[CommOverlapConfig]
+    moe_flex_dispatcher_backend: str | None
+    # Freeze options
+    pretrained_checkpoint: Optional[str]
+    freeze_language_model: bool
+    freeze_vision_model: bool
+    freeze_vision_projection: bool
+    freeze_audio_model: bool
+    # Dataset configuration
+    dataset_type: Optional[str]
+    image_folder: Optional[str]
+    tokenizer_model: Optional[str]
+    # PEFT options
+    peft: Optional[Union[str, PEFT]]
+    finetune_lr: float
+
+
+def qwen3_omni_moe_30b_a3b_pretrain_config(**user_kwargs: Unpack[Qwen3OmniCommonKwargs]) -> ConfigContainer:
+    """Return a pre-training config for Qwen3-Omni-30B-A3B-Instruct.
+
+    See `_qwen3_omni_common` for the full list of parameters.
+    """
+    recommended_kwargs: Qwen3OmniCommonKwargs = {
+        "hf_path": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+        "tensor_model_parallel_size": 1,
+        "pipeline_model_parallel_size": 1,
+        "pipeline_dtype": torch.bfloat16,
+        "expert_model_parallel_size": 8,
+        "freeze_language_model": False,
+        "freeze_vision_model": False,
+        "freeze_vision_projection": False,
+        "freeze_audio_model": False,
+    }
+    # Combine defaults with user kwargs; user values take precedence.
+    combined_kwargs: Qwen3OmniCommonKwargs = {**recommended_kwargs, **user_kwargs}
+    return _qwen3_omni_common(**combined_kwargs)
+
+
+def qwen3_omni_moe_30b_a3b_finetune_config(**user_kwargs: Unpack[Qwen3OmniCommonKwargs]) -> ConfigContainer:
+    """Return a fine-tuning config for Qwen3-Omni-30B-A3B-Instruct.
+
+    This is a Mixture-of-Experts model with 128 experts and top-8 routing.
+    Recommended to use with expert parallelism (EP) for efficient training.
+
+    See `_qwen3_omni_common` for the full list of parameters.
+    """
+
+    recommended_kwargs: Qwen3OmniCommonKwargs = {
+        "hf_path": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+        "tensor_model_parallel_size": 1,
+        "pipeline_model_parallel_size": 1,
+        "pipeline_dtype": torch.bfloat16,
+        "expert_model_parallel_size": 8,
+        "finetune_lr": 2e-5,
+        "freeze_language_model": True,
+        "freeze_vision_model": True,
+        "freeze_vision_projection": False,
+        "freeze_audio_model": True,
+        "min_lr": 2e-6,
+        "lr": 2e-5,
+        "lr_warmup_iters": 200,
+        "micro_batch_size": 1,
+        "global_batch_size": 32,
+    }
+    # Combine defaults with user kwargs; user values take precedence.
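For context, a minimal usage sketch for the two helpers above (the training entry point that consumes the returned ConfigContainer is not shown here, and the dataset path and overrides below are illustrative only):

    from megatron.bridge.recipes.qwen_omni import qwen3_omni_moe_30b_a3b_finetune_config

    # Every keyword maps to a Qwen3OmniCommonKwargs entry and overrides the recommended defaults.
    cfg = qwen3_omni_moe_30b_a3b_finetune_config(
        hf_path="Qwen/Qwen3-Omni-30B-A3B-Instruct",
        dataset_type="energon",                       # one of: mock, preloaded, hf, energon
        train_data_path=["/data/qwen_omni/energon"],  # hypothetical dataset location
        peft="lora",                                  # optional PEFT selection
        train_iters=1000,
    )
    cfg.train.global_batch_size = 16  # the returned ConfigContainer remains mutable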
+ combined_kwargs: Qwen3OmniCommonKwargs = {**recommended_kwargs, **user_kwargs} + return _qwen3_omni_common(**combined_kwargs) + + +def _qwen3_omni_common( + hf_path: str, + dir: Optional[str] = None, + name: str = "default", + # Dataset configuration + data_paths: Optional[List[str]] = None, + data_args_path: Optional[str] = None, + train_data_path: Optional[List[str]] = None, + valid_data_path: Optional[List[str]] = None, + test_data_path: Optional[List[str]] = None, + per_split_data_args_path: Optional[str] = None, + mock: bool = False, + # Model configuration + tensor_model_parallel_size: int = 4, + pipeline_model_parallel_size: int = 2, + pipeline_dtype: Optional[torch.dtype] = torch.bfloat16, + virtual_pipeline_model_parallel_size: Optional[int] = None, + context_parallel_size: int = 1, + expert_model_parallel_size: Optional[int] = 4, + expert_tensor_parallel_size: int = 1, + sequence_parallel: bool = False, + use_megatron_fsdp: bool = False, + enable_recompute: bool = False, + account_for_embedding_in_pipeline_split: bool = False, + account_for_loss_in_pipeline_split: bool = False, + # Training hyperparameters + train_iters: int = 300000, + global_batch_size: int = 32, + micro_batch_size: int = 2, + seq_length: int = 4096, + lr: float = 3e-4, + min_lr: float = 3e-5, + lr_warmup_iters: int = 500, + lr_decay_iters: Optional[int] = None, + eval_interval: int = 500, + save_interval: int = 500, + use_null_tokenizer: bool = False, + # Precision recipe + precision_config: Optional[Union[MixedPrecisionConfig, str]] = None, + comm_overlap_config: Optional[CommOverlapConfig] = None, + moe_flex_dispatcher_backend: Optional[str] = None, + # Freeze options + pretrained_checkpoint: Optional[str] = None, + freeze_language_model: bool = True, + freeze_vision_model: bool = True, + freeze_vision_projection: bool = False, + freeze_audio_model: bool = True, + # Dataset configuration + dataset_type: Optional[str] = None, + image_folder: Optional[str] = None, + tokenizer_model: Optional[str] = None, + # PEFT options + peft: Optional[Union[str, PEFT]] = None, + finetune_lr: Optional[float] = None, +) -> ConfigContainer: + """ + Create a pre-training configuration for Qwen3 Omni MoE models using a given HuggingFace path. + + Args: + hf_path (str): HuggingFace model path (e.g., "Qwen/Qwen3-30B-A3B", "Qwen/Qwen3-235B-A22B"). + dir (Optional[str]): Base directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + data_paths (Optional[List[str]]): List of paths to dataset files. If None, mock data will be used. + data_args_path (Optional[str]): Path to file containing data arguments. + train_data_path (Optional[List[str]]): List of training data paths. + valid_data_path (Optional[List[str]]): List of validation data paths. + test_data_path (Optional[List[str]]): List of test data paths. + per_split_data_args_path (Optional[str]): Path to JSON file with per-split data configuration. + mock (bool): Whether to use mock data. If True, ignores data_paths. + tensor_model_parallel_size (int): Degree of tensor model parallelism. + pipeline_model_parallel_size (int): Degree of pipeline model parallelism. + pipeline_dtype (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_model_parallel_size (Optional[int]): Size of virtual pipeline parallelism. + context_parallel_size (int): Degree of context parallelism to be passed to model_config. + expert_model_parallel_size (Optional[int]): Degree of expert parallelism for MoE. 
+ expert_tensor_parallel_size (int): Expert tensor parallelism for MoE. + sequence_parallel (bool): Whether to use sequence parallelism. + use_megatron_fsdp (bool): Whether to use Megatron FSDP. + enable_recompute (bool): Whether to enable recompute for memory optimization. + account_for_embedding_in_pipeline_split (bool): Whether to account for embedding in pipeline split. + account_for_loss_in_pipeline_split (bool): Whether to account for loss in pipeline split. + train_iters (int): Total number of training iterations. + global_batch_size (int): Global batch size for training. + micro_batch_size (int): Micro batch size for training. + seq_length (int): Sequence length for training data. + lr (float): Learning rate. + min_lr (float): Minimum learning rate for cosine decay. + lr_warmup_iters (int): Number of warmup iterations for the learning rate. + lr_decay_iters (Optional[int]): Number of iterations over which to decay the LR. + precision_config (Optional[Union[MixedPrecisionConfig, str]]): Precision configuration for the model. + comm_overlap_config (Optional[CommOverlapConfig]): Communication overlap configuration. + moe_flex_dispatcher_backend (str | None): Token dispatcher type [deepep, hybridep]. + pretrained_checkpoint (Optional[str]): Path to pretrained checkpoint. + freeze_language_model (bool): Whether to freeze the language model. + freeze_vision_model (bool): Whether to freeze the vision model. + freeze_vision_projection (bool): Whether to freeze the vision projection. + freeze_audio_model (bool): Whether to freeze the audio model. + dataset_type (Optional[str]): Type of dataset to use. + image_folder (Optional[str]): Path to image folder. + tokenizer_model (Optional[str]): Path to tokenizer model. + peft (Optional[Union[str, PEFT]]): PEFT configuration (e.g., "lora", "dora", or PEFT object). + finetune_lr (Optional[float]): Learning rate override for fine-tuning. + Returns: + ConfigContainer: Configuration for pre-training. 
+ """ + base_output_dir = dir if dir is not None else os.path.join(os.getcwd(), "nemo_experiments") + run_output_dir = os.path.join(base_output_dir, name) + checkpoint_dir = os.path.join(run_output_dir, "checkpoints") + tensorboard_dir = os.path.join(run_output_dir, "tb_logs") + + bridge = AutoBridge.from_hf_pretrained(hf_path) + model_cfg = bridge.to_megatron_provider(load_weights=False) + model_cfg.tensor_model_parallel_size = tensor_model_parallel_size + model_cfg.pipeline_model_parallel_size = pipeline_model_parallel_size + model_cfg.pipeline_dtype = pipeline_dtype + model_cfg.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size + model_cfg.context_parallel_size = context_parallel_size + model_cfg.expert_model_parallel_size = expert_model_parallel_size + model_cfg.expert_tensor_parallel_size = expert_tensor_parallel_size + model_cfg.sequence_parallel = sequence_parallel + # Freeze options + model_cfg.freeze_language_model = freeze_language_model + model_cfg.freeze_vision_model = freeze_vision_model + model_cfg.freeze_vision_projection = freeze_vision_projection + model_cfg.freeze_audio_model = freeze_audio_model + + apply_flex_dispatcher_backend(model_cfg, moe_flex_dispatcher_backend) + + if precision_config is None: + precision_config = bf16_mixed() + + # MoE-specific pipeline split configurations + if account_for_embedding_in_pipeline_split: + model_cfg.account_for_embedding_in_pipeline_split = True + if account_for_loss_in_pipeline_split: + model_cfg.account_for_loss_in_pipeline_split = True + + # Add recompute settings for memory optimization (used by some MoE models) + if enable_recompute: + model_cfg.recompute_granularity = "full" + model_cfg.recompute_method = "uniform" + model_cfg.recompute_num_layers = 1 + model_cfg.seq_length = seq_length + model_cfg.cross_entropy_fusion_impl = "te" + + # Optimizer and scheduler - use finetune_lr if provided, otherwise use lr + effective_lr = finetune_lr if finetune_lr is not None else lr + opt_config, scheduler = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=lr_warmup_iters, + lr_decay_iters=lr_decay_iters if lr_decay_iters is not None else train_iters, + max_lr=effective_lr, + min_lr=min_lr, + ) + + # PEFT config + peft_config = _default_peft_config(peft) + + # Determine dataset selection strategy. 
+ _processor_model = tokenizer_model or hf_path + _dataset_choice = dataset_type or ("mock" if mock else "hf") + + if _dataset_choice == "mock": + dataset_cfg: DatasetProvider = MockVLMConversationProvider( + seq_length=seq_length, + hf_processor_path=_processor_model, + prompt="Describe this image.", + num_workers=1, + dataloader_type="single", + data_sharding=True, + pin_memory=True, + persistent_workers=False, + create_attention_mask=True, + pad_to_max_length=True, + ) + elif _dataset_choice == "preloaded": + dataset_cfg = PreloadedVLMConversationProvider( + seq_length=seq_length, + hf_processor_path=_processor_model, + train_data_path=train_data_path[0] if isinstance(train_data_path, list) else train_data_path, + valid_data_path=valid_data_path[0] if isinstance(valid_data_path, list) else valid_data_path, + test_data_path=test_data_path[0] if isinstance(test_data_path, list) else test_data_path, + image_folder=image_folder, + num_workers=2, + dataloader_type="single", + data_sharding=True, + pin_memory=True, + persistent_workers=False, + ) + elif _dataset_choice == "hf": + dataset_cfg = HFDatasetConversationProvider( + seq_length=seq_length, + hf_processor_path=_processor_model, + maker_name="make_cord_v2_dataset", + num_workers=2, + dataloader_type="single", + data_sharding=True, + pin_memory=True, + persistent_workers=False, + ) + elif _dataset_choice == "energon": + tokenizer = AutoTokenizer.from_pretrained(_processor_model) + # Use from_pretrained to ensure correct normalization (mean/std) and config (min_pixels) + # matching Preloaded provider behavior. + image_processor = Qwen2VLImageProcessor.from_pretrained(_processor_model) + + dataset_cfg = EnergonProvider( + seq_length=seq_length, + path=train_data_path[0] if isinstance(train_data_path, list) else train_data_path, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + num_workers=2, + dataloader_type="external", + task_encoder=QwenVLTaskEncoder( + tokenizer=tokenizer, + image_processor=image_processor, + max_padding_length=seq_length, + min_pixels=200704, + max_pixels=1003520, + ), + ) + else: + raise ValueError( + f"Unsupported dataset_type '{_dataset_choice}'. Expected one of ['mock', 'preloaded', 'hf', 'energon']." 
+        )
+    # Config Container
+    cfg = ConfigContainer(
+        model=model_cfg,
+        train=TrainingConfig(
+            train_iters=train_iters,
+            global_batch_size=global_batch_size,
+            micro_batch_size=micro_batch_size,
+            manual_gc=True,
+            manual_gc_interval=100,
+            manual_gc_eval=100,
+        ),
+        validation=ValidationConfig(
+            eval_interval=eval_interval,
+            eval_iters=32,
+        ),
+        optimizer=opt_config,
+        scheduler=scheduler,
+        ddp=DistributedDataParallelConfig(
+            check_for_nan_in_grad=True,
+            grad_reduce_in_fp32=True,
+            overlap_grad_reduce=False,  # qwen3_vl does not support overlap_grad_reduce=True in current implementation
+            overlap_param_gather=False,  # qwen3_vl does not support overlap_param_gather=True in current implementation
+            average_in_collective=True,  # Not supported for Megatron FSDP for now, need to be set to False if using Megatron FSDP
+            data_parallel_sharding_strategy="optim_grads_params",  # For Megatron FSDP only
+            use_distributed_optimizer=True,
+            use_megatron_fsdp=use_megatron_fsdp,  # need use_distributed_optimizer=True
+        ),
+        dataset=dataset_cfg,
+        logger=LoggerConfig(
+            log_interval=10,
+            tensorboard_dir=tensorboard_dir,
+            log_timers_to_tensorboard=True,
+        ),
+        tokenizer=TokenizerConfig(
+            tokenizer_type="NullTokenizer" if use_null_tokenizer else "HuggingFaceTokenizer",
+            tokenizer_model=hf_path if not use_null_tokenizer else None,
+            vocab_size=DEFAULT_NULL_TOKENIZER_VOCAB_SIZE if use_null_tokenizer else None,
+        ),
+        checkpoint=CheckpointConfig(
+            pretrained_checkpoint=pretrained_checkpoint,
+            save_interval=save_interval,
+            save=checkpoint_dir,
+            load=checkpoint_dir,
+            ckpt_format="torch_dist",
+            fully_parallel_save=True,
+        ),
+        rng=RNGConfig(seed=1234),
+        peft=peft_config,
+        comm_overlap=comm_overlap_config,
+        mixed_precision=precision_config,
+    )
+
+    return cfg
diff --git a/tests/functional_tests/models/qwen_omni/test_qwen3_omni_moe_conversion.py b/tests/functional_tests/models/qwen_omni/test_qwen3_omni_moe_conversion.py
new file mode 100644
index 0000000000..d4f5d27e84
--- /dev/null
+++ b/tests/functional_tests/models/qwen_omni/test_qwen3_omni_moe_conversion.py
@@ -0,0 +1,579 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Functional tests for Qwen3 Omni Moe HF to Megatron conversion.
+
+Example run commands:
+    # Run the conversion test
+    pytest tests/functional_tests/models/qwen_omni/test_qwen3_omni_moe_conversion.py
+
+Note: These tests use small proxy/toy models for fast conversion testing.
+""" + +import json +import subprocess +from pathlib import Path + +import pytest +import torch +from transformers import ( + AutoTokenizer, + Qwen3OmniMoeConfig, + Qwen3OmniMoeForConditionalGeneration, +) +from transformers.models.qwen3_omni_moe import Qwen3OmniMoeConfig + + +HF_QWEN3_OMNI_MOE_TOY_MODEL_CONFIG = { + "architectures": ["Qwen3OmniMoeForConditionalGeneration"], + "assistant_token_id": 77091, + "code2wav_config": { + "attention_bias": False, + "attention_dropout": 0.0, + "codebook_dim": 512, + "codebook_size": 2048, + "decoder_dim": 1536, + "hidden_act": "silu", + "hidden_size": 256, + "intermediate_size": 1024, + "layer_scale_initial_scale": 0.01, + "max_position_embeddings": 2000, + "model_type": "", + "num_attention_heads": 4, + "num_hidden_layers": 2, + "num_key_value_heads": 2, + "num_quantizers": 4, + "num_semantic_quantizers": 1, + "rms_norm_eps": 1e-05, + "rope_theta": 10000, + "semantic_codebook_size": 4096, + "sliding_window": 18, + "upsample_rates": [8, 5, 4, 3], + "upsampling_ratios": [2, 2], + "vector_quantization_hidden_dimension": 128, + }, + "dtype": "bfloat16", + "enable_audio_output": True, + "im_end_token_id": 151645, + "im_start_token_id": 151644, + "model_type": "qwen3_omni_moe", + "system_token_id": 8948, + "talker_config": { + "text_config": { + "attention_bias": False, + "attention_dropout": 0, + "decoder_sparse_step": 1, + "head_dim": 16, + "hidden_act": "silu", + "hidden_size": 256, + "initializer_range": 0.02, + "intermediate_size": 512, + "max_position_embeddings": 16384, + "mlp_only_layers": [], + "moe_intermediate_size": 128, + "norm_topk_prob": True, + "num_attention_heads": 4, + "num_experts": 4, + "num_experts_per_tok": 2, + "num_hidden_layers": 4, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "interleaved": True, + "mrope_section": [16, 24, 24], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000, + "router_aux_loss_coef": 0.001, + "shared_expert_intermediate_size": 256, + "sliding_window": None, + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 3072, + }, + "accept_hidden_layer": 6, + "audio_end_token_id": 151670, + "audio_start_token_id": 151669, + "audio_token_id": 151675, + "code_predictor_config": { + "_name_or_path": "", + "add_cross_attention": False, + "architectures": None, + "attention_bias": False, + "attention_dropout": 0, + "bad_words_ids": None, + "begin_suppress_tokens": None, + "bos_token_id": None, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": None, + "decoder_start_token_id": None, + "diversity_penalty": 0.0, + "do_sample": True, + "dtype": None, + "early_stopping": False, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": None, + "exponential_decay_length_penalty": None, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "head_dim": 16, + "hidden_act": "silu", + "hidden_size": 256, + "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, + "initializer_range": 0.02, + "intermediate_size": 1024, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "layer_types": ["full_attention", "full_attention", "full_attention", "full_attention", "full_attention"], + "length_penalty": 1.0, + "max_length": 20, + "max_position_embeddings": 2048, + "max_window_layers": 28, + "min_length": 0, + "model_type": "qwen3_omni_moe_talker_code_predictor", + "no_repeat_ngram_size": 0, + "num_attention_heads": 4, + "num_beam_groups": 1, + "num_beams": 1, + "num_code_groups": 16, + 
"num_hidden_layers": 5, + "num_key_value_heads": 2, + "num_return_sequences": 1, + "output_attentions": False, + "output_hidden_states": False, + "output_scores": False, + "pad_token_id": None, + "prefix": None, + "problem_type": None, + "pruned_heads": {}, + "remove_invalid_values": False, + "repetition_penalty": 1.0, + "return_dict": True, + "return_dict_in_generate": False, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "rope_theta": 1000000, + "sep_token_id": None, + "sliding_window": None, + "suppress_tokens": None, + "task_specific_params": None, + "temperature": 1.0, + "tf_legacy_loss": False, + "tie_encoder_decoder": False, + "tie_word_embeddings": False, + "tokenizer_class": None, + "top_k": 10, + "top_p": 1.0, + "torchscript": False, + "typical_p": 1.0, + "use_bfloat16": False, + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 2048, + }, + "codec_bos_id": 2149, + "codec_eos_token_id": 2150, + "codec_nothink_id": 2155, + "codec_pad_id": 2148, + "codec_think_bos_id": 2156, + "codec_think_eos_id": 2157, + "image_token_id": 151655, + "model_type": "qwen3_omni_moe_talker", + "num_code_groups": 16, + "output_router_logits": False, + "position_id_per_seconds": 13, + "seconds_per_chunk": 2, + "spatial_merge_size": 2, + "speaker_id": {"chelsie": 2301, "ethan": 2302, "aiden": 2303}, + "thinker_hidden_size": 512, + "video_token_id": 151656, + "vision_start_token_id": 151652, + }, + "thinker_config": { + "audio_config": { + "_name_or_path": "", + "activation_dropout": 0, + "activation_function": "gelu", + "add_cross_attention": False, + "architectures": None, + "attention_dropout": 0, + "bad_words_ids": None, + "begin_suppress_tokens": None, + "bos_token_id": None, + "chunk_size_feed_forward": 0, + "conv_chunksize": 500, + "cross_attention_hidden_size": None, + "d_model": 1280, + "decoder_start_token_id": None, + "diversity_penalty": 0.0, + "do_sample": True, + "downsample_hidden_size": 480, + "dropout": 0, + "dtype": None, + "early_stopping": False, + "encoder_attention_heads": 20, + "encoder_ffn_dim": 5120, + "encoder_layers": 32, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": None, + "exponential_decay_length_penalty": None, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, + "initializer_range": 0.02, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "length_penalty": 1.0, + "max_length": 20, + "max_source_positions": 1500, + "min_length": 0, + "model_type": "qwen3_omni_moe_audio_encoder", + "n_window": 50, + "n_window_infer": 800, + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 32, + "num_mel_bins": 128, + "num_return_sequences": 1, + "output_attentions": False, + "output_dim": 2048, + "output_hidden_states": False, + "output_scores": False, + "pad_token_id": None, + "prefix": None, + "problem_type": None, + "pruned_heads": {}, + "remove_invalid_values": False, + "repetition_penalty": 1.0, + "return_dict": True, + "return_dict_in_generate": False, + "scale_embedding": False, + "sep_token_id": None, + "suppress_tokens": None, + "task_specific_params": None, + "temperature": 1.0, + "tf_legacy_loss": False, + "tie_encoder_decoder": False, + "tie_word_embeddings": True, + "tokenizer_class": None, + "top_k": 10, + "top_p": 1.0, + "torchscript": False, + "typical_p": 1.0, + "use_bfloat16": False, + }, + "audio_end_token_id": 151670, + "audio_start_token_id": 151669, + 
"audio_token_id": 151675, + "dtype": "bfloat16", + "image_token_id": 151655, + "initializer_range": 0.02, + "model_type": "qwen3_omni_moe_thinker", + "position_id_per_seconds": 13, + "seconds_per_chunk": 2, + "text_config": { + "_name_or_path": "", + "add_cross_attention": False, + "architectures": None, + "attention_bias": False, + "attention_dropout": 0.0, + "bad_words_ids": None, + "begin_suppress_tokens": None, + "bos_token_id": None, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": None, + "decoder_sparse_step": 1, + "decoder_start_token_id": None, + "diversity_penalty": 0.0, + "do_sample": True, + "dtype": None, + "early_stopping": False, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": None, + "exponential_decay_length_penalty": None, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 512, + "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, + "initializer_range": 0.02, + "intermediate_size": 384, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "length_penalty": 1.0, + "max_length": 20, + "max_position_embeddings": 16384, + "min_length": 0, + "mlp_only_layers": [], + "model_type": "qwen3_omni_moe_text", + "moe_intermediate_size": 384, + "no_repeat_ngram_size": 0, + "norm_topk_prob": True, + "num_attention_heads": 8, + "num_beam_groups": 1, + "num_beams": 1, + "num_experts": 4, + "num_experts_per_tok": 2, + "num_hidden_layers": 8, + "num_key_value_heads": 2, + "num_return_sequences": 1, + "output_attentions": False, + "output_hidden_states": False, + "output_router_logits": False, + "output_scores": False, + "pad_token_id": None, + "prefix": None, + "problem_type": None, + "pruned_heads": {}, + "remove_invalid_values": False, + "repetition_penalty": 1.0, + "return_dict": True, + "return_dict_in_generate": False, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "interleaved": True, + "mrope_interleaved": True, + "mrope_section": [16, 24, 24], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000, + "router_aux_loss_coef": 0.001, + "sep_token_id": None, + "shared_expert_intermediate_size": 0, + "sliding_window": None, + "suppress_tokens": None, + "task_specific_params": None, + "temperature": 1.0, + "tf_legacy_loss": False, + "tie_encoder_decoder": False, + "tie_word_embeddings": False, + "tokenizer_class": None, + "top_k": 10, + "top_p": 1.0, + "torchscript": False, + "typical_p": 1.0, + "use_bfloat16": False, + "use_cache": True, + "use_qk_norm": True, + "use_sliding_window": False, + "vocab_size": 152064, + }, + "user_token_id": 872, + "video_token_id": 151656, + "vision_config": { + "_name_or_path": "", + "add_cross_attention": False, + "apply_vit_abs_pos_embed": True, + "architectures": None, + "bad_words_ids": None, + "begin_suppress_tokens": None, + "bos_token_id": None, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": None, + "decoder_start_token_id": None, + "deepstack_visual_indexes": [1, 2, 3], + "depth": 27, + "diversity_penalty": 0.0, + "do_sample": True, + "dtype": None, + "early_stopping": False, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": None, + "exponential_decay_length_penalty": None, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 288, + "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, + "image_size": 384, + "in_channels": 3, + "in_chans": 3, + "initializer_range": 
0.02, + "intermediate_size": 4304, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "length_penalty": 1.0, + "max_length": 20, + "min_length": 0, + "model_type": "qwen3_omni_moe_vision_encoder", + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_heads": 2, + "num_return_sequences": 1, + "out_hidden_size": 2048, + "output_attentions": False, + "output_hidden_states": False, + "output_scores": False, + "pad_token_id": None, + "patch_size": 16, + "prefix": None, + "problem_type": None, + "pruned_heads": {}, + "remove_invalid_values": False, + "repetition_penalty": 1.0, + "return_dict": True, + "return_dict_in_generate": False, + "sep_token_id": None, + "spatial_merge_size": 2, + "spatial_patch_size": 16, + "suppress_tokens": None, + "task_specific_params": None, + "temperature": 1.0, + "temporal_patch_size": 2, + "tf_legacy_loss": False, + "tie_encoder_decoder": False, + "tie_word_embeddings": True, + "tokenizer_class": None, + "tokens_per_second": 2, + "top_k": 10, + "top_p": 1.0, + "torchscript": False, + "typical_p": 1.0, + "use_bfloat16": False, + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + }, + "transformers_version": "4.57.0.dev0", + "tts_bos_token_id": 151672, + "tts_eos_token_id": 151673, + "tts_pad_token_id": 151671, + "user_token_id": 872, +} + + +class TestQwen3OmniMoeConversion: + """ + Test Qwen3 Omni Moe model Conversion + """ + + @pytest.fixture(scope="class") + def qwen3_omni_moe_toy_model_path(self, tmp_path_factory): + """ + Create and save a HuggingFace Qwen3 Omni MoE toy model to a temporary directory. + + Args: + tmp_path_factory: Pytest temporary path factory for class-scoped fixtures + + Returns: + str: Path to the saved HuggingFace MoE model directory + """ + # Create a temporary directory for this test class + temp_dir = tmp_path_factory.mktemp("qwen3_omni_moe_generation_toy_model") + model_dir = temp_dir / "qwen3_omni_moe_toy" + + # Create Qwen3 VL MoE config from the toy model config + config = Qwen3OmniMoeConfig(**HF_QWEN3_OMNI_MOE_TOY_MODEL_CONFIG) + config.torch_dtype = torch.bfloat16 + + # Set rope_scaling on text_config + if hasattr(config.thinker_config, "text_config") and config.thinker_config.text_config is not None: + config.thinker_config.text_config.rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]} + + # Create model with random weights and convert to bfloat16 + model = Qwen3OmniMoeForConditionalGeneration(config) + model = model.to(dtype=torch.bfloat16) + + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Omni-30B-A3B-Instruct") + tokenizer.save_pretrained(model_dir) + + model.save_pretrained(model_dir, safe_serialization=True) + config_path = model_dir / "config.json" + with open(config_path, "w") as f: + json.dump(HF_QWEN3_OMNI_MOE_TOY_MODEL_CONFIG, f, indent=2) + + print(f"Created MoE toy model at: {model_dir}") + return str(model_dir) + + def test_moe_toy_model_creation(self, qwen3_omni_moe_toy_model_path): + """Test MoE toy model creation.""" + model_path = Path(qwen3_omni_moe_toy_model_path) + assert model_path.exists() + config_file = model_path / "config.json" + assert config_file.exists() + with open(config_file) as f: + config_data = json.load(f) + + assert config_data["model_type"] == "qwen3_omni_moe" + + _ = Qwen3OmniMoeForConditionalGeneration.from_pretrained( + qwen3_omni_moe_toy_model_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=False, + ) + + @pytest.mark.run_only_on("GPU") + 
@pytest.mark.parametrize("tp,pp", [(2, 1)]) + def test_moe_conversion(self, qwen3_omni_moe_toy_model_path, tmp_path, tp, pp): + """Test MoE model conversion.""" + test_output_dir = tmp_path / "qwen3_omni_moe_test" + test_output_dir.mkdir(exist_ok=True) + + cmd = [ + "python", + "-m", + "torch.distributed.run", + "--nproc_per_node=2", + "--nnodes=1", + "-m", + "coverage", + "run", + "--data-file=/opt/Megatron-Bridge/.coverage", + "--source=/opt/Megatron-Bridge/", + "--parallel-mode", + "examples/conversion/hf_megatron_roundtrip_multi_gpu.py", + "--hf-model-id", + qwen3_omni_moe_toy_model_path, + "--output-dir", + str(test_output_dir), + "--tp", + str(tp), + "--pp", + str(pp), + ] + + result = subprocess.run( + cmd, capture_output=True, text=True, cwd=Path(__file__).parent.parent.parent.parent.parent + ) + + if result.returncode != 0: + print(f"STDOUT: {result.stdout}") + print(f"STDERR: {result.stderr}") + assert False, f"MoE conversion failed with return code {result.returncode}" + + model_name = Path(qwen3_omni_moe_toy_model_path).name + converted_model_dir = test_output_dir / model_name + assert converted_model_dir.exists() + + config_file = converted_model_dir / "config.json" + assert config_file.exists() + + with open(config_file) as f: + saved_config = json.load(f) + + assert saved_config["model_type"] == "qwen3_omni_moe" diff --git a/tests/functional_tests/models/qwen_omni/test_qwen3_omni_moe_generation.py b/tests/functional_tests/models/qwen_omni/test_qwen3_omni_moe_generation.py new file mode 100644 index 0000000000..90c556530b --- /dev/null +++ b/tests/functional_tests/models/qwen_omni/test_qwen3_omni_moe_generation.py @@ -0,0 +1,567 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Functional tests for Qwen3 Omni Moe HF to Megatron generation. + +Example run commands: + # Run the generation test + pytest tests/functional_tests/models/qwen_omni/test_qwen3_omni_moe_generation.py + +Note: This test use small proxy/toy models for fast generation testing. 
+""" + +import json +import subprocess +from pathlib import Path + +import pytest +import torch +from transformers import ( + AutoTokenizer, + Qwen3OmniMoeConfig, + Qwen3OmniMoeForConditionalGeneration, +) +from transformers.models.qwen3_omni_moe import Qwen3OmniMoeConfig + + +HF_QWEN3_OMNI_MOE_TOY_MODEL_CONFIG = { + "architectures": ["Qwen3OmniMoeForConditionalGeneration"], + "assistant_token_id": 77091, + "code2wav_config": { + "attention_bias": False, + "attention_dropout": 0.0, + "codebook_dim": 512, + "codebook_size": 2048, + "decoder_dim": 1536, + "hidden_act": "silu", + "hidden_size": 256, + "intermediate_size": 1024, + "layer_scale_initial_scale": 0.01, + "max_position_embeddings": 2000, + "model_type": "", + "num_attention_heads": 4, + "num_hidden_layers": 2, + "num_key_value_heads": 2, + "num_quantizers": 4, + "num_semantic_quantizers": 1, + "rms_norm_eps": 1e-05, + "rope_theta": 10000, + "semantic_codebook_size": 4096, + "sliding_window": 18, + "upsample_rates": [8, 5, 4, 3], + "upsampling_ratios": [2, 2], + "vector_quantization_hidden_dimension": 128, + }, + "dtype": "bfloat16", + "enable_audio_output": True, + "im_end_token_id": 151645, + "im_start_token_id": 151644, + "model_type": "qwen3_omni_moe", + "system_token_id": 8948, + "talker_config": { + "text_config": { + "attention_bias": False, + "attention_dropout": 0, + "decoder_sparse_step": 1, + "head_dim": 16, + "hidden_act": "silu", + "hidden_size": 256, + "initializer_range": 0.02, + "intermediate_size": 512, + "max_position_embeddings": 16384, + "mlp_only_layers": [], + "moe_intermediate_size": 128, + "norm_topk_prob": True, + "num_attention_heads": 4, + "num_experts": 4, + "num_experts_per_tok": 2, + "num_hidden_layers": 4, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "interleaved": True, + "mrope_section": [16, 24, 24], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000, + "router_aux_loss_coef": 0.001, + "shared_expert_intermediate_size": 256, + "sliding_window": None, + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 3072, + }, + "accept_hidden_layer": 6, + "audio_end_token_id": 151670, + "audio_start_token_id": 151669, + "audio_token_id": 151675, + "code_predictor_config": { + "_name_or_path": "", + "add_cross_attention": False, + "architectures": None, + "attention_bias": False, + "attention_dropout": 0, + "bad_words_ids": None, + "begin_suppress_tokens": None, + "bos_token_id": None, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": None, + "decoder_start_token_id": None, + "diversity_penalty": 0.0, + "do_sample": True, + "dtype": None, + "early_stopping": False, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": None, + "exponential_decay_length_penalty": None, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "head_dim": 16, + "hidden_act": "silu", + "hidden_size": 256, + "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, + "initializer_range": 0.02, + "intermediate_size": 1024, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "layer_types": ["full_attention", "full_attention", "full_attention", "full_attention", "full_attention"], + "length_penalty": 1.0, + "max_length": 20, + "max_position_embeddings": 2048, + "max_window_layers": 28, + "min_length": 0, + "model_type": "qwen3_omni_moe_talker_code_predictor", + "no_repeat_ngram_size": 0, + "num_attention_heads": 4, + "num_beam_groups": 1, + "num_beams": 1, + "num_code_groups": 16, + 
"num_hidden_layers": 5, + "num_key_value_heads": 2, + "num_return_sequences": 1, + "output_attentions": False, + "output_hidden_states": False, + "output_scores": False, + "pad_token_id": None, + "prefix": None, + "problem_type": None, + "pruned_heads": {}, + "remove_invalid_values": False, + "repetition_penalty": 1.0, + "return_dict": True, + "return_dict_in_generate": False, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "rope_theta": 1000000, + "sep_token_id": None, + "sliding_window": None, + "suppress_tokens": None, + "task_specific_params": None, + "temperature": 1.0, + "tf_legacy_loss": False, + "tie_encoder_decoder": False, + "tie_word_embeddings": False, + "tokenizer_class": None, + "top_k": 10, + "top_p": 1.0, + "torchscript": False, + "typical_p": 1.0, + "use_bfloat16": False, + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 2048, + }, + "codec_bos_id": 2149, + "codec_eos_token_id": 2150, + "codec_nothink_id": 2155, + "codec_pad_id": 2148, + "codec_think_bos_id": 2156, + "codec_think_eos_id": 2157, + "image_token_id": 151655, + "model_type": "qwen3_omni_moe_talker", + "num_code_groups": 16, + "output_router_logits": False, + "position_id_per_seconds": 13, + "seconds_per_chunk": 2, + "spatial_merge_size": 2, + "speaker_id": {"chelsie": 2301, "ethan": 2302, "aiden": 2303}, + "thinker_hidden_size": 512, + "video_token_id": 151656, + "vision_start_token_id": 151652, + }, + "thinker_config": { + "audio_config": { + "_name_or_path": "", + "activation_dropout": 0, + "activation_function": "gelu", + "add_cross_attention": False, + "architectures": None, + "attention_dropout": 0, + "bad_words_ids": None, + "begin_suppress_tokens": None, + "bos_token_id": None, + "chunk_size_feed_forward": 0, + "conv_chunksize": 500, + "cross_attention_hidden_size": None, + "d_model": 1280, + "decoder_start_token_id": None, + "diversity_penalty": 0.0, + "do_sample": True, + "downsample_hidden_size": 480, + "dropout": 0, + "dtype": None, + "early_stopping": False, + "encoder_attention_heads": 20, + "encoder_ffn_dim": 5120, + "encoder_layers": 32, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": None, + "exponential_decay_length_penalty": None, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, + "initializer_range": 0.02, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "length_penalty": 1.0, + "max_length": 20, + "max_source_positions": 1500, + "min_length": 0, + "model_type": "qwen3_omni_moe_audio_encoder", + "n_window": 50, + "n_window_infer": 800, + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 32, + "num_mel_bins": 128, + "num_return_sequences": 1, + "output_attentions": False, + "output_dim": 2048, + "output_hidden_states": False, + "output_scores": False, + "pad_token_id": None, + "prefix": None, + "problem_type": None, + "pruned_heads": {}, + "remove_invalid_values": False, + "repetition_penalty": 1.0, + "return_dict": True, + "return_dict_in_generate": False, + "scale_embedding": False, + "sep_token_id": None, + "suppress_tokens": None, + "task_specific_params": None, + "temperature": 1.0, + "tf_legacy_loss": False, + "tie_encoder_decoder": False, + "tie_word_embeddings": True, + "tokenizer_class": None, + "top_k": 10, + "top_p": 1.0, + "torchscript": False, + "typical_p": 1.0, + "use_bfloat16": False, + }, + "audio_end_token_id": 151670, + "audio_start_token_id": 151669, + 
"audio_token_id": 151675, + "dtype": "bfloat16", + "image_token_id": 151655, + "initializer_range": 0.02, + "model_type": "qwen3_omni_moe_thinker", + "position_id_per_seconds": 13, + "seconds_per_chunk": 2, + "text_config": { + "_name_or_path": "", + "add_cross_attention": False, + "architectures": None, + "attention_bias": False, + "attention_dropout": 0.0, + "bad_words_ids": None, + "begin_suppress_tokens": None, + "bos_token_id": None, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": None, + "decoder_sparse_step": 1, + "decoder_start_token_id": None, + "diversity_penalty": 0.0, + "do_sample": True, + "dtype": None, + "early_stopping": False, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": None, + "exponential_decay_length_penalty": None, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 512, + "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, + "initializer_range": 0.02, + "intermediate_size": 384, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "length_penalty": 1.0, + "max_length": 20, + "max_position_embeddings": 16384, + "min_length": 0, + "mlp_only_layers": [], + "model_type": "qwen3_omni_moe_text", + "moe_intermediate_size": 384, + "no_repeat_ngram_size": 0, + "norm_topk_prob": True, + "num_attention_heads": 8, + "num_beam_groups": 1, + "num_beams": 1, + "num_experts": 4, + "num_experts_per_tok": 2, + "num_hidden_layers": 8, + "num_key_value_heads": 2, + "num_return_sequences": 1, + "output_attentions": False, + "output_hidden_states": False, + "output_router_logits": False, + "output_scores": False, + "pad_token_id": None, + "prefix": None, + "problem_type": None, + "pruned_heads": {}, + "remove_invalid_values": False, + "repetition_penalty": 1.0, + "return_dict": True, + "return_dict_in_generate": False, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "interleaved": True, + "mrope_interleaved": True, + "mrope_section": [16, 24, 24], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000, + "router_aux_loss_coef": 0.001, + "sep_token_id": None, + "shared_expert_intermediate_size": 0, + "sliding_window": None, + "suppress_tokens": None, + "task_specific_params": None, + "temperature": 1.0, + "tf_legacy_loss": False, + "tie_encoder_decoder": False, + "tie_word_embeddings": False, + "tokenizer_class": None, + "top_k": 10, + "top_p": 1.0, + "torchscript": False, + "typical_p": 1.0, + "use_bfloat16": False, + "use_cache": True, + "use_qk_norm": True, + "use_sliding_window": False, + "vocab_size": 152064, + }, + "user_token_id": 872, + "video_token_id": 151656, + "vision_config": { + "_name_or_path": "", + "add_cross_attention": False, + "apply_vit_abs_pos_embed": True, + "architectures": None, + "bad_words_ids": None, + "begin_suppress_tokens": None, + "bos_token_id": None, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": None, + "decoder_start_token_id": None, + "deepstack_visual_indexes": [1, 2, 3], + "depth": 27, + "diversity_penalty": 0.0, + "do_sample": True, + "dtype": None, + "early_stopping": False, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": None, + "exponential_decay_length_penalty": None, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 288, + "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, + "image_size": 384, + "in_channels": 3, + "in_chans": 3, + "initializer_range": 
0.02, + "intermediate_size": 4304, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "length_penalty": 1.0, + "max_length": 20, + "min_length": 0, + "model_type": "qwen3_omni_moe_vision_encoder", + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_heads": 2, + "num_return_sequences": 1, + "out_hidden_size": 2048, + "output_attentions": False, + "output_hidden_states": False, + "output_scores": False, + "pad_token_id": None, + "patch_size": 16, + "prefix": None, + "problem_type": None, + "pruned_heads": {}, + "remove_invalid_values": False, + "repetition_penalty": 1.0, + "return_dict": True, + "return_dict_in_generate": False, + "sep_token_id": None, + "spatial_merge_size": 2, + "spatial_patch_size": 16, + "suppress_tokens": None, + "task_specific_params": None, + "temperature": 1.0, + "temporal_patch_size": 2, + "tf_legacy_loss": False, + "tie_encoder_decoder": False, + "tie_word_embeddings": True, + "tokenizer_class": None, + "tokens_per_second": 2, + "top_k": 10, + "top_p": 1.0, + "torchscript": False, + "typical_p": 1.0, + "use_bfloat16": False, + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + }, + "transformers_version": "4.57.0.dev0", + "tts_bos_token_id": 151672, + "tts_eos_token_id": 151673, + "tts_pad_token_id": 151671, + "user_token_id": 872, +} + + +class TestQwen3OmniMoeGeneration: + """ + Test Qwen3 Omni Moe model generation using HF to Megatron conversion with vision inputs. + Uses small proxy/toy models for fast generation testing. + """ + + @pytest.fixture(scope="class") + def qwen3_omni_moe_toy_model_path(self, tmp_path_factory): + """ + Create and save a HuggingFace Qwen3 Omni MoE toy model to a temporary directory. + + Args: + tmp_path_factory: Pytest temporary path factory for class-scoped fixtures + + Returns: + str: Path to the saved HuggingFace MoE model directory + """ + # Create a temporary directory for this test class + temp_dir = tmp_path_factory.mktemp("qwen3_omni_moe_generation_toy_model") + model_dir = temp_dir / "qwen3_omni_moe_toy" + + # Create Qwen3 VL MoE config from the toy model config + config = Qwen3OmniMoeConfig(**HF_QWEN3_OMNI_MOE_TOY_MODEL_CONFIG) + config.torch_dtype = torch.bfloat16 + + # Set rope_scaling on text_config + if hasattr(config.thinker_config, "text_config") and config.thinker_config.text_config is not None: + config.thinker_config.text_config.rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]} + + # Create model with random weights and convert to bfloat16 + model = Qwen3OmniMoeForConditionalGeneration(config) + model = model.to(dtype=torch.bfloat16) + + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Omni-30B-A3B-Instruct") + tokenizer.save_pretrained(model_dir) + + # Also save the image processor + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained("Qwen/Qwen3-Omni-30B-A3B-Instruct") + processor.save_pretrained(model_dir) + + model.save_pretrained(model_dir, safe_serialization=True) + config_path = model_dir / "config.json" + with open(config_path, "w") as f: + json.dump(HF_QWEN3_OMNI_MOE_TOY_MODEL_CONFIG, f, indent=2) + + print(f"Created MoE toy model at: {model_dir}") + return str(model_dir) + + @pytest.mark.run_only_on("GPU") + def test_qwen3_omni_moe_image_generation(self, qwen3_omni_moe_toy_model_path): + """ + Test Qwen3 Omni MoE toy model with image generation and EP=2. + Uses a small proxy MoE model instead of the full 30B model for fast testing. 
+        Uses real image to test vision-language pipeline with corrected vision config.
+
+        Args:
+            qwen3_omni_moe_toy_model_path: Path to the toy Qwen3 Omni MoE model (from fixture)
+        """
+        cmd = [
+            "python",
+            "-m",
+            "torch.distributed.run",
+            "--nproc_per_node=2",
+            "examples/conversion/hf_to_megatron_generate_vlm.py",
+            f"--hf_model_path={qwen3_omni_moe_toy_model_path}",
+            "--prompt=Describe this image.",
+            "--ep=2",
+        ]
+
+        try:
+            result = subprocess.run(
+                cmd,
+                capture_output=True,
+                text=True,
+                cwd=Path(__file__).parent.parent.parent.parent.parent,
+                timeout=300,  # 5-minute limit; matches the TimeoutExpired handling below
+            )
+
+            # Print output for debugging
+            print("\n" + "=" * 80)
+            print("STDOUT:")
+            print(result.stdout)
+            print("\n" + "=" * 80)
+            print("STDERR:")
+            print(result.stderr)
+            print("=" * 80 + "\n")
+
+            if result.returncode != 0:
+                assert False, f"Qwen3 Omni MoE toy model generation failed with return code {result.returncode}"
+
+            print("SUCCESS: Qwen3 Omni MoE toy model generation test completed successfully")
+
+        except subprocess.TimeoutExpired:
+            assert False, "Qwen3 Omni MoE toy model generation test timed out after 5 minutes"
+        except Exception as e:
+            print(f"Error during Qwen3 Omni MoE toy model generation test: {e}")
+            raise
diff --git a/tests/unit_tests/models/qwen_omni/modeling_qwen3_omni/test_omni_model.py b/tests/unit_tests/models/qwen_omni/modeling_qwen3_omni/test_omni_model.py
new file mode 100644
index 0000000000..bc13a6b2eb
--- /dev/null
+++ b/tests/unit_tests/models/qwen_omni/modeling_qwen3_omni/test_omni_model.py
@@ -0,0 +1,450 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Unit tests for Qwen3OmniMoe Model implementation.
+
+Run with: torchrun --nproc_per_node=8 -m pytest tests/unit_tests/models/qwen_omni/modeling_qwen3_omni/test_omni_model.py
+Or for single GPU: pytest tests/unit_tests/models/qwen_omni/modeling_qwen3_omni/test_omni_model.py
+"""
+
+import datetime
+import os
+
+import numpy as np
+import pytest
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from megatron.core import parallel_state
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
+from megatron.core.process_groups_config import ProcessGroupCollection
+from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
+from transformers import AutoProcessor, Qwen3OmniMoeConfig
+
+from megatron.bridge.models.qwen_omni.modeling_qwen3_omni.model import Qwen3OmniMoeModel
+from megatron.bridge.models.qwen_omni.modeling_qwen3_omni.transformer_config import Qwen3OmniTransformerConfig
+
+
+@pytest.fixture(scope="module")
+def processor():
+    """Load HuggingFace processor once for all tests."""
+    return AutoProcessor.from_pretrained("Qwen/Qwen3-Omni-30B-A3B-Instruct")
+
+
+@pytest.fixture(scope="module")
+def hf_config():
+    """Load HuggingFace config once for all tests."""
+    return Qwen3OmniMoeConfig.from_pretrained("Qwen/Qwen3-Omni-30B-A3B-Instruct")
+
+
+@pytest.fixture
+def random_image():
+    """Generate a random PIL image."""
+    return np.random.randint(0, 255, size=(24, 24, 3), dtype=np.uint8)
+
+
+@pytest.fixture
+def random_video():
+    """Generate a random video."""
+    return np.random.randint(0, 255, size=(2, 3, 24, 44), dtype=np.uint8)
+
+
+@pytest.fixture
+def random_audio():
+    """Generate a random audio."""
+    return np.random.randint(-1, 32767, size=(800), dtype=np.int16)
+
+
+class TestQwen3OmniMoeModel:
+    """Test suite for Qwen3OmniMoe Model."""
+
+    @classmethod
+    def setup_class(cls):
+        """Setup distributed process group once for all tests in this class."""
+        if not dist.is_initialized():
+            os.environ["MASTER_ADDR"] = "127.0.0.1"
+            os.environ["MASTER_PORT"] = "29500"
+            os.environ["RANK"] = "0"
+            os.environ["LOCAL_RANK"] = "0"
+            os.environ["WORLD_SIZE"] = "1"
+
+            device_count = torch.cuda.device_count()
+            if device_count > 0:
+                torch.cuda.set_device(0)
+
+            dist.init_process_group(
+                backend="nccl" if device_count > 0 else "gloo",
+                world_size=1,
+                rank=0,
+                timeout=datetime.timedelta(minutes=30),
+            )
+
+    @classmethod
+    def teardown_class(cls):
+        """Teardown distributed process group once after all tests in this class."""
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
+    def _setup_parallel_state(self, tp_size=1, ep_size=1, pp_size=1, cp_size=1):
+        """Setup Megatron parallel state with specified parallelism configuration.
+ + Args: + tp_size: Tensor model parallel size + ep_size: Expert model parallel size + pp_size: Pipeline model parallel size + cp_size: Context parallel size + """ + # Clean up any existing parallel state before initializing + if parallel_state.model_parallel_is_initialized(): + parallel_state.destroy_model_parallel() + + parallel_state.initialize_model_parallel( + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + virtual_pipeline_model_parallel_size=None, + context_parallel_size=cp_size, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=1, + ) + + model_parallel_cuda_manual_seed(123) + + def teardown_method(self): + """Teardown Megatron parallel state after each test method.""" + parallel_state.destroy_model_parallel() + + @staticmethod + def get_thinker_transformer_config(hf_config): + """Create a thinker transformer config for testing. + + Returns: + TransformerConfig: Configuration for the thinker model. + """ + return hf_config.thinker_config + + @staticmethod + def get_talker_transformer_config(hf_config): + """Create a talker transformer config for testing. + + Returns: + TransformerConfig: Configuration for the thinker model. + """ + return hf_config.talker_config + + @staticmethod + def get_code2wav_transformer_config(hf_config): + """Create a code2wav transformer config for testing. + + Returns: + TransformerConfig: Configuration for the thinker model. + """ + return hf_config.code2wav_config + + @staticmethod + def get_language_transformer_config(hf_config): + """Create a language transformer config for testing. + + Uses actual Qwen3-Omni-30B-A3B model sizes to ensure compatibility + with the vision model output (2048 hidden size). + + Args: + hf_config: HuggingFace config object. + + Returns: + Qwen3OmniTransformerConfig: Configuration for the language model. 
+        """
+        thinker_config = hf_config.thinker_config
+        return Qwen3OmniTransformerConfig(
+            # Use actual model dimensions from HF config
+            num_layers=4,  # Reduced for testing (actual: thinker_config.text_config.num_hidden_layers)
+            hidden_size=thinker_config.text_config.hidden_size,  # Must match vision output: 2048
+            num_attention_heads=thinker_config.text_config.num_attention_heads,
+            num_query_groups=thinker_config.text_config.num_key_value_heads,
+            kv_channels=thinker_config.text_config.hidden_size // thinker_config.text_config.num_attention_heads,
+            ffn_hidden_size=thinker_config.text_config.intermediate_size,
+            # Qwen3-Omni specific
+            vocab_size=thinker_config.text_config.vocab_size,
+            language_max_sequence_length=thinker_config.text_config.max_position_embeddings,
+            # Vision parameters
+            patch_size=thinker_config.vision_config.patch_size,
+            temporal_patch_size=thinker_config.vision_config.temporal_patch_size,
+            in_channels=thinker_config.vision_config.in_channels,
+            spatial_merge_size=thinker_config.vision_config.spatial_merge_size,
+            out_hidden_size=thinker_config.text_config.hidden_size,  # Vision output = language input
+            # RoPE settings
+            rotary_base=thinker_config.text_config.rope_theta,
+            rotary_percent=1.0,
+            mrope_section=thinker_config.text_config.rope_scaling.get("mrope_section", [24, 20, 20]),
+            hf_text_config=thinker_config.text_config,
+            # Training settings
+            normalization="RMSNorm",
+            activation_func=F.silu,
+            gated_linear_unit=True,
+            add_bias_linear=False,
+            add_qkv_bias=True,
+            layernorm_epsilon=thinker_config.text_config.rms_norm_eps,
+            bf16=False,
+            use_cpu_initialization=True,
+            hidden_dropout=0.0,
+            attention_dropout=thinker_config.text_config.attention_dropout,
+            num_moe_experts=2,  # Reduced for testing (actual: thinker_config.text_config.num_experts)
+        )
+
+    @staticmethod
+    def get_language_model_layer_spec():
+        """Create a GPT layer spec for the language model.
+
+        Returns:
+            ModuleSpec: Layer specification for transformer layers.
+        """
+        language_model_layer_spec = get_gpt_layer_with_transformer_engine_spec(
+            num_experts=2,  # Reduced for testing (actual: hf_config.thinker_config.text_config.num_experts)
+            moe_grouped_gemm=True,
+            qk_layernorm=True,
+            fp8=False,
+        )
+        return language_model_layer_spec
+
+    @staticmethod
+    def get_data_batch(processor, random_image, random_video, random_audio):
+        """Generate a batch of data for model forward pass.
+
+        Args:
+            processor: HuggingFace processor.
+            random_image: Random image array.
+            random_video: Random video array.
+            random_audio: Random audio samples.
+
+        Returns:
+            dict: A dictionary containing all inputs needed for model forward pass:
+                - input_ids: Token IDs [batch, seq_len]
+                - attention_mask: Attention mask [batch, seq_len]
+                - pixel_values: Image pixel values [batch, channels, height, width]
+                - image_grid_thw: Image grid dimensions [num_images, 3] (temporal, height, width)
+                - pixel_values_videos: Video pixel values (None for images only)
+                - video_grid_thw: Video grid dimensions (None for images only)
+                - input_features: Audio features (None if no audio)
+                - feature_attention_mask: Audio attention mask (None if no audio)
+                - video_second_per_grid: Video seconds per grid
+        """
+        # Create a sample message with image, video, audio, and text
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": random_image,  # Pass the image array directly
+                    },
+                    {
+                        "type": "video",
+                        "video": random_video,
+                    },
+                    {
+                        "type": "audio",
+                        "audio": random_audio,
+                    },
+                    {"type": "text", "text": "Describe this image, video and audio."},
+                ],
+            }
+        ]
+
+        # Process inputs using HuggingFace processor
+        inputs = processor.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+
+        batch = {
+            "input_ids": inputs["input_ids"],
+            "attention_mask": inputs.get("attention_mask"),
+            "pixel_values": inputs.get("pixel_values"),
+            "image_grid_thw": inputs.get("image_grid_thw"),
+            "pixel_values_videos": inputs.get("pixel_values_videos"),
+            "video_grid_thw": inputs.get("video_grid_thw"),
+            "input_features": inputs.get("input_features"),
+            "feature_attention_mask": inputs.get("feature_attention_mask"),
+            "video_second_per_grid": inputs.get("video_second_per_grid"),
+            "position_ids": None,
+            "labels": None,
+        }
+
+        # Move tensors to CUDA if available
+        if torch.cuda.is_available():
+            for key, value in batch.items():
+                if value is not None and isinstance(value, torch.Tensor):
+                    batch[key] = value.cuda()
+
+        return batch
+
+    @pytest.mark.timeout(50)
+    @pytest.mark.parametrize(
+        "freeze_all",
+        [True, False],
+    )
+    def test_model_freeze_api(self, freeze_all, hf_config):
+        """Test model freeze API."""
+        self._setup_parallel_state(tp_size=1, ep_size=1, pp_size=1)
+        pg_collection = ProcessGroupCollection.use_mpu_process_groups()
+        assert pg_collection is not None
+        assert pg_collection.tp is not None
+        assert pg_collection.pp is not None
+        assert pg_collection.cp is not None
+        assert pg_collection.ep is not None
+        assert pg_collection.embd is not None
+
+        language_transformer_config = self.get_language_transformer_config(hf_config)
+        language_model_layer_spec = self.get_language_model_layer_spec()
+        thinker_transformer_config = self.get_thinker_transformer_config(hf_config)
+        talker_transformer_config = self.get_talker_transformer_config(hf_config)
+        code2wav_transformer_config = self.get_code2wav_transformer_config(hf_config)
+
+        model = Qwen3OmniMoeModel(
+            language_transformer_config=language_transformer_config,
+            language_transformer_layer_spec=language_model_layer_spec,
+            thinker_transformer_config=thinker_transformer_config,
+            talker_transformer_config=talker_transformer_config,
+            code2wav_transformer_config=code2wav_transformer_config,
+            parallel_output=True,
+            pre_process=True,
+            post_process=True,
+            add_encoder=True,
+            add_decoder=True,
+            pg_collection=pg_collection,
+        )
+
+        if torch.cuda.is_available():
+            model.to("cuda")
+
+        model.freeze(
+            freeze_language_model=freeze_all,
+            freeze_vision_model=freeze_all,
+            freeze_vision_projection=freeze_all,
+            freeze_audio_model=freeze_all,
+        )
+
+        # Every parameter's requires_grad flag should be the inverse of freeze_all
+        for name, param in model.named_parameters():
+            assert param.requires_grad != freeze_all, f"{name=}"
+
+    @pytest.mark.timeout(50)
+    def test_shared_embedding_or_output_weight(self, hf_config):
+        """Test shared_embedding_or_output_weight method."""
+        self._setup_parallel_state(tp_size=1, ep_size=1, pp_size=1)
+        pg_collection = ProcessGroupCollection.use_mpu_process_groups()
+        assert pg_collection is not None
+        assert pg_collection.tp is not None
+        assert pg_collection.pp is not None
+        assert pg_collection.cp is not None
+        assert pg_collection.ep is not None
+        assert pg_collection.embd is not None
+
+        language_transformer_config = self.get_language_transformer_config(hf_config)
+        language_model_layer_spec = self.get_language_model_layer_spec()
+        thinker_transformer_config = self.get_thinker_transformer_config(hf_config)
+        talker_transformer_config = self.get_talker_transformer_config(hf_config)
+        code2wav_transformer_config = self.get_code2wav_transformer_config(hf_config)
+
+        # Test with add_decoder=True
+        model = Qwen3OmniMoeModel(
+            language_transformer_config=language_transformer_config,
+            language_transformer_layer_spec=language_model_layer_spec,
+            thinker_transformer_config=thinker_transformer_config,
+            talker_transformer_config=talker_transformer_config,
+            code2wav_transformer_config=code2wav_transformer_config,
+            parallel_output=True,
+            pre_process=True,
+            post_process=True,
+            add_encoder=True,
+            add_decoder=True,
+            pg_collection=pg_collection,
+        )
+        weight = model.shared_embedding_or_output_weight()
+        assert weight is not None
+
+        # Test with add_decoder=False
+        model = Qwen3OmniMoeModel(
+            language_transformer_config=language_transformer_config,
+            language_transformer_layer_spec=language_model_layer_spec,
+            thinker_transformer_config=thinker_transformer_config,
+            talker_transformer_config=talker_transformer_config,
+            code2wav_transformer_config=code2wav_transformer_config,
+            parallel_output=True,
+            pre_process=True,
+            post_process=True,
+            add_encoder=True,
+            add_decoder=False,
+            pg_collection=pg_collection,
+        )
+        weight_no_decoder = model.shared_embedding_or_output_weight()
+        assert weight_no_decoder is None
+
+    @pytest.mark.timeout(50)
+    def test_set_input_tensor(self, hf_config):
+        """Test set_input_tensor method."""
+        self._setup_parallel_state(tp_size=1, ep_size=1, pp_size=1)
+        pg_collection = ProcessGroupCollection.use_mpu_process_groups()
+        assert pg_collection is not None
+        assert pg_collection.tp is not None
+        assert pg_collection.pp is not None
+        assert pg_collection.cp is not None
+        assert pg_collection.ep is not None
+        assert pg_collection.embd is not None
+
+        language_transformer_config = self.get_language_transformer_config(hf_config)
+        language_model_layer_spec = self.get_language_model_layer_spec()
+        thinker_transformer_config = self.get_thinker_transformer_config(hf_config)
+        talker_transformer_config = self.get_talker_transformer_config(hf_config)
+        code2wav_transformer_config = self.get_code2wav_transformer_config(hf_config)
+
+        # Test with pre_process=True
+        model = Qwen3OmniMoeModel(
+            language_transformer_config=language_transformer_config,
+            language_transformer_layer_spec=language_model_layer_spec,
+            thinker_transformer_config=thinker_transformer_config,
+            talker_transformer_config=talker_transformer_config,
+            code2wav_transformer_config=code2wav_transformer_config,
+            parallel_output=True,
+            pre_process=True,
+            post_process=True,
+            add_encoder=True,
+            add_decoder=True,
+            pg_collection=pg_collection,
+        )
+
+        test_tensor = torch.randn(2, 4, 2048)
+
+        # Test with a list containing a single tensor
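+        # (Megatron's pipeline schedules pass activations to set_input_tensor as a list,
+        # so the test mirrors that calling convention.)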
+        model.set_input_tensor([test_tensor])
+        assert model.thinker.encoder_hidden_state is not None
+
+        # Test with pre_process=False
+        model = Qwen3OmniMoeModel(
+            language_transformer_config=language_transformer_config,
+            language_transformer_layer_spec=language_model_layer_spec,
+            thinker_transformer_config=thinker_transformer_config,
+            talker_transformer_config=talker_transformer_config,
+            code2wav_transformer_config=code2wav_transformer_config,
+            parallel_output=True,
+            pre_process=False,
+            post_process=True,
+            add_encoder=True,
+            add_decoder=True,
+            pg_collection=pg_collection,
+        )
+
+        # This should set the input tensor on the language model instead
+        model.set_input_tensor([test_tensor])
+        # No assertion here as it only updates internal state
diff --git a/tests/unit_tests/models/qwen_omni/modeling_qwen3_omni/test_omni_rope.py b/tests/unit_tests/models/qwen_omni/modeling_qwen3_omni/test_omni_rope.py
new file mode 100644
index 0000000000..8ce769754f
--- /dev/null
+++ b/tests/unit_tests/models/qwen_omni/modeling_qwen3_omni/test_omni_rope.py
@@ -0,0 +1,146 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Run with: pytest tests/unit_tests/models/qwen_omni/modeling_qwen3_omni/test_omni_rope.py
+"""
+
+import torch
+
+from megatron.bridge.models.qwen_omni.modeling_qwen3_omni.rope import get_rope_index
+
+
+class TestQwen3OmniMoeRope:
+    """Test suite for the Qwen3OmniMoe RoPE utilities."""
+
+    def test_get_rope_index_text_only(self):
+        """Test get_rope_index with text-only input."""
+        batch_size, seq_len = 2, 8
+        input_ids = torch.randint(0, 1000, (batch_size, seq_len))
+
+        position_ids, deltas = get_rope_index(
+            spatial_merge_size=2,
+            image_token_id=151655,
+            video_token_id=151656,
+            audio_token_id=151675,
+            vision_start_token_id=151652,
+            audio_start_token_id=151669,
+            position_id_per_seconds=13,
+            input_ids=input_ids,
+        )
+
+        assert position_ids.shape == (3, batch_size, seq_len)
+        assert deltas.shape == (batch_size, 1)
+
+    def test_get_rope_index_with_attention_mask(self):
+        """Test get_rope_index with an attention mask."""
+        batch_size, seq_len = 1, 16
+        input_ids = torch.randint(0, 1000, (batch_size, seq_len))
+        attention_mask = torch.ones((batch_size, seq_len))
+
+        position_ids, deltas = get_rope_index(
+            spatial_merge_size=2,
+            image_token_id=151655,
+            video_token_id=151656,
+            audio_token_id=151675,
+            vision_start_token_id=151652,
+            audio_start_token_id=151669,
+            position_id_per_seconds=13,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+        )
+
+        assert position_ids.shape == (3, batch_size, seq_len)
+        assert deltas.shape == (batch_size, 1)
+
+    def test_get_rope_index_with_image(self):
+        """Test get_rope_index with an image grid."""
+        batch_size, seq_len = 1, 16
+        input_ids = torch.randint(0, 1000, (batch_size, seq_len))
+        # Insert image tokens
+        input_ids[0, 4] = 151652  # vision_start_token_id
+        input_ids[0, 5] = 151655  # image_token_id
+        image_grid_thw = torch.tensor([[1, 4, 4]])
+
+        position_ids, deltas = get_rope_index(
+            spatial_merge_size=2,
+            image_token_id=151655,
+            video_token_id=151656,
+            audio_token_id=151675,
+            vision_start_token_id=151652,
+            audio_start_token_id=151669,
+            position_id_per_seconds=13,
+            input_ids=input_ids,
+            image_grid_thw=image_grid_thw,
+        )
+
+        assert position_ids.shape == (3, batch_size, seq_len)
+        assert deltas.shape == (batch_size, 1)
+
+    def test_get_rope_index_with_video(self):
+        """Test get_rope_index with a video grid."""
+        batch_size, seq_len = 1, 16
+        input_ids = torch.randint(0, 1000, (batch_size, seq_len))
+        # Insert video tokens
+        input_ids[0, 4] = 151652  # vision_start_token_id
+        input_ids[0, 5] = 151656  # video_token_id
+        video_grid_thw = torch.tensor([[1, 4, 4]])
+        second_per_grids = torch.tensor([2])
+
+        position_ids, deltas = get_rope_index(
+            spatial_merge_size=2,
+            image_token_id=151655,
+            video_token_id=151656,
+            audio_token_id=151675,
+            vision_start_token_id=151652,
+            audio_start_token_id=151669,
+            position_id_per_seconds=13,
+            input_ids=input_ids,
+            video_grid_thw=video_grid_thw,
+            second_per_grids=second_per_grids,
+        )
+
+        assert position_ids.shape == (3, batch_size, seq_len)
+        assert deltas.shape == (batch_size, 1)
+
+    def test_get_rope_index_with_audio_in_video(self):
+        """Test get_rope_index with audio interleaved in video."""
+        batch_size, seq_len = 1, 16
+        input_ids = torch.randint(0, 1000, (batch_size, seq_len))
+        # Insert video and audio tokens
+        input_ids[0, 4] = 151652  # vision_start_token_id
+        input_ids[0, 5] = 151669  # audio_start_token_id
+        input_ids[0, 6] = 151656  # video_token_id
+        input_ids[0, 7] = 151675  # audio_token_id
+        video_grid_thw = torch.tensor([[1, 4, 4]])
+        audio_seqlens = torch.tensor([1])
+        second_per_grids = torch.tensor([2])
+
+        position_ids, deltas = get_rope_index(
+            spatial_merge_size=2,
+            image_token_id=151655,
+            video_token_id=151656,
+            audio_token_id=151675,
+            vision_start_token_id=151652,
+            audio_start_token_id=151669,
+            position_id_per_seconds=13,
+            input_ids=input_ids,
+            video_grid_thw=video_grid_thw,
+            use_audio_in_video=True,
+            audio_seqlens=audio_seqlens,
+            second_per_grids=second_per_grids,
+        )
+
+        assert position_ids.shape == (3, batch_size, seq_len)
+        assert deltas.shape == (batch_size, 1)