Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/megatron/bridge/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@
Qwen25ModelProvider72B,
Qwen25ModelProvider500M,
)
from megatron.bridge.models.qwen3_asr import (
Qwen3ASRBridge,
Qwen3ASRModel,
Qwen3ASRModelProvider,
)
Comment thread
yuekaizhang marked this conversation as resolved.
from megatron.bridge.models.qwen_omni import (
Qwen25OmniBridge,
Qwen25OmniModel,
Expand Down Expand Up @@ -243,6 +248,10 @@
"NemotronVLModel",
"NemotronVLBridge",
"NemotronNano12Bv2VLModelProvider",
# ASR Models
"Qwen3ASRBridge",
"Qwen3ASRModel",
"Qwen3ASRModelProvider",
# Omni Models
"Qwen25OmniModel",
"Qwen25OmniBridge",
Expand Down
7 changes: 7 additions & 0 deletions src/megatron/bridge/models/conversion/auto_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1370,6 +1370,13 @@ def _causal_lm_architecture(self):
try:
Comment thread
yuekaizhang marked this conversation as resolved.
return getattr(transformers, resolved_arch)
except AttributeError:
# Fall back to string-based lookup for custom models not in transformers
# (e.g. Qwen3ASRForConditionalGeneration from qwen_asr package).
# This mirrors the auto_map path and works with string-registered bridges.
if hasattr(model_bridge.get_model_bridge, "_exact_types"):
registry = model_bridge.get_model_bridge._exact_types
if resolved_arch in registry:
return resolved_arch
Comment thread
yuekaizhang marked this conversation as resolved.
Outdated
raise ValueError(
f"\n✗ Architecture class '{resolved_arch}' not found in transformers\n\n"
f"This could mean:\n"
Expand Down
24 changes: 24 additions & 0 deletions src/megatron/bridge/models/qwen3_asr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from megatron.bridge.models.qwen3_asr.modeling_qwen3_asr.model import Qwen3ASRModel
from megatron.bridge.models.qwen3_asr.qwen3_asr_bridge import Qwen3ASRBridge
from megatron.bridge.models.qwen3_asr.qwen3_asr_provider import Qwen3ASRModelProvider


__all__ = [
"Qwen3ASRBridge",
"Qwen3ASRModel",
"Qwen3ASRModelProvider",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
111 changes: 111 additions & 0 deletions src/megatron/bridge/models/qwen3_asr/modeling_qwen3_asr/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from megatron.core import InferenceParams
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec

from megatron.bridge.models.qwen3_asr.modeling_qwen3_asr.thinker_model import Qwen3ASRThinkerModel
from megatron.bridge.models.qwen3_asr.modeling_qwen3_asr.transformer_config import Qwen3ASRTransformerConfig


class Qwen3ASRModel(MegatronModule):
"""Qwen3-ASR Model.

Top-level wrapper that delegates to Qwen3ASRThinkerModel.
Audio-only model (no vision/video), follows Qwen2.5-Omni pattern simplified for ASR.
"""

def __init__(
self,
language_transformer_config: Qwen3ASRTransformerConfig,
language_transformer_layer_spec: ModuleSpec,
thinker_transformer_config,
parallel_output: bool = True,
pre_process: bool = True,
post_process: bool = True,
add_encoder: bool = True,
add_decoder: bool = True,
pg_collection: ProcessGroupCollection | None = None,
) -> None:
super().__init__(config=language_transformer_config)

self.thinker = Qwen3ASRThinkerModel(
language_transformer_config,
language_transformer_layer_spec,
thinker_transformer_config,
parallel_output,
pre_process,
post_process,
add_encoder,
add_decoder,
pg_collection,
)

def shared_embedding_or_output_weight(self):
"""This is a convenience method to surface the language model's word embeddings, which is
necessary for `finalize_model_grads._allreduce_word_embedding_grads`."""
return self.thinker.shared_embedding_or_output_weight()

def set_input_tensor(self, input_tensor) -> None:
return self.thinker.set_input_tensor(input_tensor)

def freeze(
self,
freeze_language_model: bool = False,
freeze_audio_model: bool = False,
):
"""Freeze model modules.

Args:
freeze_language_model (bool): Freeze the language model module.
freeze_audio_model (bool): Freeze the audio model module.
"""
return self.thinker.freeze(
freeze_language_model,
freeze_audio_model,
)

def forward(
self,
input_ids: torch.Tensor,
input_features: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
loss_mask: torch.Tensor | None = None,
inference_params: InferenceParams | None = None,
packed_seq_params: PackedSeqParams | None = None,
extra_block_kwargs: dict | None = None,
feature_attention_mask: torch.Tensor | None = None,
audio_feature_lengths: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
return self.thinker(
input_ids=input_ids,
input_features=input_features,
position_ids=position_ids,
attention_mask=attention_mask,
labels=labels,
loss_mask=loss_mask,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
extra_block_kwargs=extra_block_kwargs,
feature_attention_mask=feature_attention_mask,
audio_feature_lengths=audio_feature_lengths,
**kwargs,
)
48 changes: 48 additions & 0 deletions src/megatron/bridge/models/qwen3_asr/modeling_qwen3_asr/rope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


def get_rope_index(
input_ids: torch.LongTensor,
attention_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Calculate the 3D rope index for Qwen3-ASR.

Simplified version for audio-only model: just cumulative position IDs
expanded to 3 MRoPE dimensions. No special vision/video handling needed.

Ported from HF Qwen3ASRPreTrainedModelForConditionalGeneration.get_rope_index.

Args:
input_ids: Input token IDs of shape (batch_size, sequence_length).
attention_mask: Attention mask of shape (batch_size, sequence_length).

Returns:
position_ids: Position IDs of shape (3, batch_size, sequence_length).
mrope_position_deltas: Position deltas of shape (batch_size, 1).
"""
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

attention_mask = attention_mask.to(input_ids.device)
position_ids = attention_mask.float().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True)

return position_ids, mrope_position_deltas
Loading