Support Qwen3-ASR Megatron Bridge#2836
Conversation
- Add Qwen3-ASR model bridge with audio encoder, thinker model, and RoPE
- Add transformer config conversion for Qwen3-ASR
- Add provider bridge with kv_channels support
- Fix auto_bridge to support custom models not in transformers
- Fix bridge registration to use string-based source
- Fix dtype mismatch in audio encoder forward pass
- Add unit tests for Qwen3-ASR model

Signed-off-by: zhangyuekai <zhangyuekai@foxmail.com>
Signed-off-by: root <root@h20-2.cm.cluster>
Megatron Inference Results: AUDIO_URL_3="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/1272-128104-0000.flac"
📝 Walkthrough

This pull request introduces Qwen3-ASR (automatic speech recognition) support to the Megatron framework. It adds a complete new model implementation that combines a HuggingFace audio encoder with a Qwen3-based language model, includes a bridge for converting HuggingFace Qwen3-ASR models, and integrates the new architecture into the framework with fallback registry support.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client as Model Forward
    participant ThinkerModel as Qwen3ASRThinkerModel
    participant AudioEnc as Audio Encoder
    participant LanguageModel as Qwen3VLGPTModel
    participant RopeIdx as RopeIndex

    Client->>ThinkerModel: forward(input_ids, input_features, ...)
    activate ThinkerModel
    ThinkerModel->>RopeIdx: get_rope_index(input_ids, attention_mask)
    RopeIdx-->>ThinkerModel: position_ids, mrope_deltas
    ThinkerModel->>AudioEnc: get_audio_features(input_features)
    activate AudioEnc
    AudioEnc-->>ThinkerModel: audio_embeddings
    deactivate AudioEnc
    ThinkerModel->>ThinkerModel: merge audio embeddings at audio token positions
    ThinkerModel->>LanguageModel: forward(decoder_input, position_ids, attention_mask, labels, ...)
    activate LanguageModel
    LanguageModel-->>ThinkerModel: output logits/loss
    deactivate LanguageModel
    ThinkerModel-->>Client: model output
    deactivate ThinkerModel
```
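The "merge audio embeddings at audio token positions" step in the diagram can be sketched roughly as follows. This is a minimal sketch, not the PR's implementation; `AUDIO_TOKEN_ID`, the function name, and the tensor layout are illustrative assumptions:

```python
import torch

AUDIO_TOKEN_ID = 151646  # hypothetical placeholder id for the audio token


def merge_audio_embeddings(
    input_ids: torch.Tensor,     # [batch, seq]
    text_embeds: torch.Tensor,   # [batch, seq, hidden]
    audio_embeds: torch.Tensor,  # [num_audio_tokens, hidden]
    audio_token_id: int = AUDIO_TOKEN_ID,
) -> torch.Tensor:
    # Scatter audio frame embeddings into the positions occupied by audio tokens.
    merged = text_embeds.clone()
    mask = input_ids == audio_token_id
    # Casting to the text dtype mirrors the dtype-mismatch fix listed in the changelog.
    merged[mask] = audio_embeds.to(merged.dtype)
    return merged
```

The number of `True` positions in `mask` must equal `audio_embeds.shape[0]`, i.e. the encoder must emit exactly one embedding per audio placeholder token.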
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 2 failed
❌ Failed checks (2 warnings)
✅ Passed checks (2 passed)
Actionable comments posted: 5
🧹 Nitpick comments (3)
tests/unit_tests/models/qwen3_asr/modeling_qwen3_asr/test_qwen3_asr_model.py (1)

84-85: Add the unit-test marker. These new tests live under `tests/unit_tests/...`, but the class is only tagged with `timeout`. As per coding guidelines (`tests/**/*.py`): use pytest markers to categorize tests (unit, integration, system).

Suggested fix:

```diff
-@pytest.mark.timeout(30)
+@pytest.mark.unit
+@pytest.mark.timeout(30)
 class TestQwen3ASRModel:
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tests/unit_tests/models/qwen3_asr/modeling_qwen3_asr/test_qwen3_asr_model.py` around lines 84-85: the test class TestQwen3ASRModel is missing the unit-test marker; update the class decorators to include the pytest marker for unit tests (e.g., add `@pytest.mark.unit` alongside the existing `@pytest.mark.timeout(30)`), or define a module-level `pytestmark = [pytest.mark.unit, pytest.mark.timeout(30)]` so the tests under TestQwen3ASRModel are correctly categorized as unit tests per the test guidelines.

src/megatron/bridge/models/qwen3_asr/qwen3_asr_provider.py (1)
48: Simplify `default_factory` - the lambda wrapper is unnecessary. The lambda is redundant when the factory just calls the class constructor with no arguments.

Suggested simplification:

```diff
-    thinker_config: Qwen3ASRThinkerConfig = field(default_factory=lambda: Qwen3ASRThinkerConfig())
+    thinker_config: Qwen3ASRThinkerConfig = field(default_factory=Qwen3ASRThinkerConfig)
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/models/qwen3_asr/qwen3_asr_provider.py` at line 48: the field definition for thinker_config uses a redundant lambda as default_factory; update the dataclass field declaration to use the constructor directly, i.e. replace `default_factory=lambda: Qwen3ASRThinkerConfig()` with `default_factory=Qwen3ASRThinkerConfig` on the thinker_config field to simplify the code.

src/megatron/bridge/models/qwen3_asr/modeling_qwen3_asr/model.py (1)
83-97: Add type hints for untyped parameters. Per coding guidelines, all function arguments should have type hints. The following parameters are missing type annotations:

- `input_features`
- `feature_attention_mask`
- `audio_feature_lengths`

Suggested type annotations:

```diff
 def forward(
     self,
     input_ids: torch.Tensor,
-    input_features=None,
+    input_features: torch.Tensor | None = None,
     position_ids: torch.Tensor | None = None,
     attention_mask: torch.Tensor | None = None,
     labels: torch.Tensor | None = None,
     loss_mask: torch.Tensor | None = None,
     inference_params: InferenceParams | None = None,
     packed_seq_params: PackedSeqParams | None = None,
     extra_block_kwargs: dict | None = None,
-    feature_attention_mask=None,
-    audio_feature_lengths=None,
+    feature_attention_mask: torch.Tensor | None = None,
+    audio_feature_lengths: torch.Tensor | None = None,
     **kwargs,
 ) -> torch.Tensor:
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/models/qwen3_asr/modeling_qwen3_asr/model.py` around lines 83-97: the forward signature is missing type hints for input_features, feature_attention_mask, and audio_feature_lengths; update the forward method signature to add explicit types - e.g. `input_features: torch.Tensor | None`, `feature_attention_mask: torch.Tensor | None`, and `audio_feature_lengths: torch.Tensor | Sequence[int] | None` (or whichever int-sequence type the codebase uses) - and keep the existing types for input_ids, position_ids, attention_mask, labels, loss_mask, `inference_params: InferenceParams | None`, and `packed_seq_params: PackedSeqParams | None` so static checkers and IDEs can validate usage. Ensure the types used are imported in the module or add them if necessary.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/megatron/bridge/models/__init__.py`:
- Around line 124-128: The top-level imports of Qwen3ASRBridge, Qwen3ASRModel,
and Qwen3ASRModelProvider currently force-import the external qwen_asr package;
either add qwen_asr to project dependencies in pyproject.toml or guard the
imports with lazy loading—implement a module-level __getattr__ (or try/except
import inside a function) that imports and returns Qwen3ASRBridge,
Qwen3ASRModel, and Qwen3ASRModelProvider on-demand, and apply the same change
for the other import block referenced (lines around the second occurrence).
In `@src/megatron/bridge/models/qwen3_asr/modeling_qwen3_asr/thinker_model.py`:
- Around line 233-261: The bug is that position_ids remain global while
combined_embeddings is split for context-parallel (CP) ranks; when cp_size>1 and
packed_seq_params is None you must shard position_ids the same way as
combined_embeddings so each CP rank gets matching RoPE positions. After
computing position_ids via get_rope_index (and before applying SP
padding/scatter), detect the same cp_size/cp_rank condition used for
combined_embeddings and call the same splitter (split_data_cp_rank) on
position_ids (preserving the same dims/order), then continue with the existing
SP padding/replicate padding logic so position_ids and combined_embeddings stay
aligned; ensure you reference combined_embeddings, split_data_cp_rank, cp_size,
cp_rank, position_ids, get_rope_index, and packed_seq_params when editing.
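The alignment requirement above can be sketched as follows. This is an illustrative stand-in for `split_data_cp_rank` using contiguous even chunks; real Megatron CP splitting may interleave chunks for load balancing, so treat the splitter body as an assumption:

```python
import torch


def split_for_cp_rank(tensor: torch.Tensor, cp_size: int, cp_rank: int, seq_dim: int = 1) -> torch.Tensor:
    # Stand-in for split_data_cp_rank: contiguous even shard along the sequence dim.
    return torch.chunk(tensor, cp_size, dim=seq_dim)[cp_rank]


def shard_positions_with_embeddings(
    combined_embeddings: torch.Tensor,  # [batch, seq, hidden]
    position_ids: torch.Tensor,         # [batch, seq]
    cp_size: int,
    cp_rank: int,
):
    # Shard position_ids with the SAME splitter and the same sequence dim as the
    # embeddings, so each CP rank applies RoPE with positions matching its tokens.
    emb = split_for_cp_rank(combined_embeddings, cp_size, cp_rank, seq_dim=1)
    pos = split_for_cp_rank(position_ids, cp_size, cp_rank, seq_dim=1)
    return emb, pos
```

The bug being fixed is exactly the case where only the first call is made: rank 1 would then pair its local tokens with global positions 0..seq-1 instead of its own slice.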
- Around line 55-84: The constructor currently types pg_collection as optional
but immediately dereferences it (self.cp_group = pg_collection.cp etc.), causing
AttributeError; either make pg_collection a required parameter (remove the "=
None" and update callers) or guard/materialize defaults before use by resolving
a default ProcessGroupCollection (e.g., via your parallel_state default getter)
and assigning self.pg_collection to that resolved instance before setting
self.cp_group, self.tp_group, self.pp_group and self.embd_group; also keep the
existing assert for embd but ensure it runs against the non-None
self.pg_collection.
- Around line 163-180: In the loop in thinker_model.py where you iterate "for
input_feature, feature_len in zip(input_features, feature_lens):", convert the
0-dim tensor feature_len to a Python int before using it for slicing (e.g., use
feature_len.item() for the slice limit) and pass a proper 1-D tensor/shape to
the audio model for feature_lens (e.g., construct a
torch.tensor([feature_len_value], device=input_feature.device, dtype=torch.long)
or otherwise match expected type), and change the zip call to
zip(input_features, feature_lens, strict=True) to enforce batch-size
consistency; update references to feature_len used for slicing and for
feature_lens argument accordingly.
In
`@tests/unit_tests/models/qwen3_asr/modeling_qwen3_asr/test_qwen3_asr_model.py`:
- Around line 88-113: The teardown currently destroys the global distributed
process group even when setup_class skipped initialization; modify setup_class
to record ownership (e.g., set a class attribute like cls._owns_process_group =
True only when you call dist.init_process_group) and ensure teardown_class only
calls dist.destroy_process_group() if dist.is_initialized() and
cls._owns_process_group is True; set the flag to False when skipping init so you
don't destroy pre-initialized groups, and clear or reset the flag after
destroying in teardown_class.
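The ownership-flag pattern described above can be sketched as follows. A toy `_DistStub` replaces `torch.distributed` so the sketch runs without a real process group; in the test file the same logic would call `dist.is_initialized()` / `dist.init_process_group()` / `dist.destroy_process_group()`:

```python
class _DistStub:
    """Toy stand-in for torch.distributed so the ownership pattern is runnable."""

    def __init__(self):
        self.initialized = False

    def is_initialized(self) -> bool:
        return self.initialized

    def init_process_group(self) -> None:
        self.initialized = True

    def destroy_process_group(self) -> None:
        self.initialized = False


dist = _DistStub()


class TestHarness:
    _owns_process_group = False

    @classmethod
    def setup_class(cls):
        if not dist.is_initialized():
            dist.init_process_group()
            cls._owns_process_group = True   # we created it, so we may destroy it
        else:
            cls._owns_process_group = False  # pre-initialized elsewhere; leave it alone

    @classmethod
    def teardown_class(cls):
        if dist.is_initialized() and cls._owns_process_group:
            dist.destroy_process_group()
        cls._owns_process_group = False
```

This way a test class never tears down a process group that some outer fixture or runner initialized before it.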
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: dd9f3779-471e-44ea-80e2-f4c687b581c9
📒 Files selected for processing (13)
- src/megatron/bridge/models/__init__.py
- src/megatron/bridge/models/conversion/auto_bridge.py
- src/megatron/bridge/models/qwen3_asr/__init__.py
- src/megatron/bridge/models/qwen3_asr/modeling_qwen3_asr/__init__.py
- src/megatron/bridge/models/qwen3_asr/modeling_qwen3_asr/model.py
- src/megatron/bridge/models/qwen3_asr/modeling_qwen3_asr/rope.py
- src/megatron/bridge/models/qwen3_asr/modeling_qwen3_asr/thinker_model.py
- src/megatron/bridge/models/qwen3_asr/modeling_qwen3_asr/transformer_config.py
- src/megatron/bridge/models/qwen3_asr/qwen3_asr_bridge.py
- src/megatron/bridge/models/qwen3_asr/qwen3_asr_provider.py
- tests/unit_tests/models/qwen3_asr/__init__.py
- tests/unit_tests/models/qwen3_asr/modeling_qwen3_asr/__init__.py
- tests/unit_tests/models/qwen3_asr/modeling_qwen3_asr/test_qwen3_asr_model.py
Signed-off-by: root <zhangyuekai@foxmail.com>
/claude review
Please check AI comments.
Signed-off-by: root <zhangyuekai@foxmail.com>
Signed-off-by: root <zhangyuekai@foxmail.com>
Done. Many thanks for the review.
@yuekaizhang need to add functional tests and an L0 bash script. This is required for all new models.
Signed-off-by: root <zhangyuekai@foxmail.com>
Done.
@chtruong814 @yaoyu-33 All issues have been resolved. I was wondering if you could help review and run the CI/CD tests please. Thanks.
/ok to test 78e019f
What does this PR do?
Support https://github.com/QwenLM/Qwen3-ASR in M-bridge.
Changelog
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items you can still open a "Draft" PR.
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Tests