Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions examples/models/audio_lm/inference_qwen3_asr.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#!/usr/bin/env bash
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Usage:
# bash examples/models/audio_lm/inference_qwen3_asr.sh

# Fail fast: exit on any error, treat unset variables as errors, and make a
# pipeline fail if any stage fails (stronger than plain `set -e`).
set -euo pipefail

# HF checkpoint to load and the local path used for the Megatron-format export.
export HF_MODEL="Qwen/Qwen3-ASR-1.7B"
export MEGATRON_PATH="examples/models/audio_lm/qwen3_asr"


# Public sample clip used as the ASR smoke-test input.
AUDIO_URL="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/1272-128104-0000.flac"

echo "============================================"
echo "Qwen3-ASR Megatron Bridge Inference Test"
echo "============================================"

# Option 1: Direct inference from HuggingFace (no conversion)
echo ""
echo "Option 1: Direct inference from HuggingFace..."
echo "Audio: ${AUDIO_URL}"
echo ""

# Variables are quoted so paths/model IDs with spaces don't word-split.
uv run --no-sync python -m torch.distributed.run --nproc_per_node=2 examples/conversion/hf_to_megatron_generate_audio_lm.py \
    --hf_model_path "${HF_MODEL}" \
    --audio_url "${AUDIO_URL}" \
    --prompt "" \
    --tp 2 \
    --max_new_tokens 50

# Option 2: Convert to Megatron format, then run inference on the converted
# checkpoint. Runs unconditionally as part of this test script.

echo ""
echo "Option 2: Converting HF checkpoint to Megatron format..."
uv run --no-sync python examples/conversion/convert_checkpoints.py import \
    --hf-model "${HF_MODEL}" \
    --megatron-path "${MEGATRON_PATH}"

echo ""
echo "Running inference on converted checkpoint..."
# NOTE(review): unlike Option 1, this launch omits --nproc_per_node/--tp, so it
# runs single-process — confirm that is intended for the converted checkpoint.
uv run --no-sync python -m torch.distributed.run examples/conversion/hf_to_megatron_generate_audio_lm.py \
    --hf_model_path "${HF_MODEL}" \
    --megatron_model_path "${MEGATRON_PATH}/iter_0000000" \
    --audio_url "${AUDIO_URL}" \
    --prompt "" \
    --max_new_tokens 50

echo ""
echo "============================================"
echo "Inference complete!"
echo "============================================"
9 changes: 9 additions & 0 deletions src/megatron/bridge/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@
Qwen2AudioModel,
Qwen2AudioModelProvider,
)
from megatron.bridge.models.qwen3_asr import (
Qwen3ASRBridge,
Qwen3ASRModel,
Qwen3ASRModelProvider,
)
from megatron.bridge.models.qwen_omni import (
Qwen25OmniBridge,
Qwen25OmniModel,
Expand Down Expand Up @@ -210,6 +215,10 @@
"NemotronVLModel",
"NemotronVLBridge",
"NemotronNano12Bv2VLModelProvider",
# ASR Models
"Qwen3ASRBridge",
"Qwen3ASRModel",
"Qwen3ASRModelProvider",
# Omni Models
"Qwen25OmniModel",
"Qwen25OmniBridge",
Expand Down
7 changes: 7 additions & 0 deletions src/megatron/bridge/models/conversion/auto_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1370,6 +1370,13 @@ def _causal_lm_architecture(self):
try:
return getattr(transformers, resolved_arch)
except AttributeError:
# Fall back to string-based lookup for custom models not in transformers
# (e.g. Qwen3ASRForConditionalGeneration from qwen_asr package).
# This mirrors the auto_map path and works with string-registered bridges.
if hasattr(model_bridge.get_model_bridge, "_exact_types"):
registry = model_bridge.get_model_bridge._exact_types
if resolved_arch in registry:
return resolved_arch
raise ValueError(
f"\n✗ Architecture class '{resolved_arch}' not found in transformers\n\n"
f"This could mean:\n"
Expand Down
25 changes: 25 additions & 0 deletions src/megatron/bridge/models/qwen3_asr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Importing hf_qwen3_asr has a required side effect: it registers the
# "qwen3_asr" model type with transformers' Auto classes. Keep this import
# first so registration happens before any of the symbols below are used.
import megatron.bridge.models.qwen3_asr.hf_qwen3_asr  # triggers AutoConfig.register("qwen3_asr", ...)
from megatron.bridge.models.qwen3_asr.modeling_qwen3_asr.model import Qwen3ASRModel
from megatron.bridge.models.qwen3_asr.qwen3_asr_bridge import Qwen3ASRBridge
from megatron.bridge.models.qwen3_asr.qwen3_asr_provider import Qwen3ASRModelProvider


# Public API of the qwen3_asr package, re-exported at
# megatron.bridge.models level by the parent package's __init__.
__all__ = [
    "Qwen3ASRBridge",
    "Qwen3ASRModel",
    "Qwen3ASRModelProvider",
]
11 changes: 11 additions & 0 deletions src/megatron/bridge/models/qwen3_asr/hf_qwen3_asr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Register the Qwen3-ASR model family with transformers' Auto classes so that
# AutoConfig / AutoModel / AutoProcessor can resolve the "qwen3_asr" model type
# (replaces the registration previously performed by qwen_asr.inference).
from transformers import AutoConfig, AutoModel, AutoProcessor

from .configuration_qwen3_asr import Qwen3ASRAudioEncoderConfig, Qwen3ASRConfig, Qwen3ASRThinkerConfig
from .modeling_qwen3_asr import Qwen3ASRAudioEncoder, Qwen3ASRForConditionalGeneration
from .processing_qwen3_asr import Qwen3ASRProcessor


# exist_ok=True keeps this module import-idempotent: without it, transformers
# raises ValueError if "qwen3_asr" is already registered — e.g. when the
# external qwen_asr package has registered it first, or when registration runs
# twice through different import paths.
AutoConfig.register("qwen3_asr", Qwen3ASRConfig, exist_ok=True)
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration, exist_ok=True)
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor, exist_ok=True)
Loading