Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/megatron/bridge/inference/vlm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,17 @@ def setup_model_and_tokenizer(
return inference_wrapped_model, processor


def _expose_decoder_from_language_model(model):
"""Recursively get language_model from model and expose decoder, handling wrapped modules."""
current = model
while hasattr(current, "module"):
current = current.module

if hasattr(current, "language_model"):
language_model = current.language_model
current.decoder = language_model.decoder


def setup_inference_wrapper(
model,
tokenizer,
Expand All @@ -131,6 +142,8 @@ def setup_inference_wrapper(
wrapper_cls = QwenVLInferenceWrapper
if isinstance(config, Qwen25VLModelProvider):
hidden_size = config.hidden_size
# Expose decoder for MCore Inference Engine compatibility (used by get_mamba_inference_state_config_from_model)
_expose_decoder_from_language_model(mcore_model)
else:
hidden_size = config.language_transformer_config.hidden_size
else:
Expand Down
3 changes: 0 additions & 3 deletions src/megatron/bridge/models/qwen_vl/modeling_qwen25_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,6 @@ def __init__(
self.share_embeddings_and_output_weights = config.share_embeddings_and_output_weights
self.shared_embedding_or_output_weight = self.language_model.shared_embedding_or_output_weight

# Expose decoder for MCore Inference Engine compatibility (used by get_mamba_inference_state_config_from_model)
self.decoder = self.language_model.decoder

# Bind methods from HF's Qwen2_5_VLModel to this instance
# get_placeholder_mask is only available in transformers 4.55+
if is_transformers_min_version("4.55.0"):
Expand Down
46 changes: 41 additions & 5 deletions tests/unit_tests/inference/vlm/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,25 +236,54 @@ class TestSetupInferenceWrapper:

# QwenVLInferenceWrapper is patched in megatron.bridge.inference.vlm.base, so the
# test inspects the arguments it is called with instead of building a real wrapper.
@patch("megatron.bridge.inference.vlm.base.QwenVLInferenceWrapper")
def test_setup_inference_wrapper_qwen25(self, mock_wrapper_cls, mock_tokenizer):
# NOTE(review): this MagicMock binding is overwritten by the MockObject
# instance assigned to mock_model further down, before mock_model is used.
mock_model = MagicMock()
"""Test Qwen25 setup with module.language_model.decoder structure."""

# Create mock objects with nested structure
# A plain class instance only has the attributes assigned to it explicitly, so
# hasattr(obj, "module") / hasattr(obj, "decoder") are False until set here.
class MockObject:
pass

mock_decoder = MagicMock()

# Build the nested structure: model.module.language_model.decoder
mock_language_model = MockObject()
mock_language_model.decoder = mock_decoder

mock_module = MockObject()
mock_module.language_model = mock_language_model

mock_model = MockObject()
mock_model.module = mock_module
# spec= restricts the mock to the Qwen25VLModelProvider interface and makes
# isinstance(config, Qwen25VLModelProvider) succeed.
mock_model.config = MagicMock(spec=Qwen25VLModelProvider)
mock_model.config.hidden_size = 1024
# cuda() and to() hand back the same object; presumably setup_inference_wrapper
# chains these calls on the model (not visible here; confirm against the caller).
mock_model.cuda = MagicMock(return_value=mock_model)
mock_model.to = MagicMock(return_value=mock_model)
mock_model.eval = MagicMock()

_wrapper = setup_inference_wrapper(mock_model, mock_tokenizer)

# Verify decoder was exposed at module level
# mock_module never had `decoder` assigned above, so these assertions only pass
# if the call under test set it to language_model.decoder.
assert hasattr(mock_module, "decoder")
assert mock_module.decoder is mock_decoder

mock_wrapper_cls.assert_called_once()
# Check InferenceWrapperConfig was created with correct hidden_size
# Args are positional: (model, InferenceWrapperConfig)
call_args = mock_wrapper_cls.call_args
inference_config = call_args[0][1] # Second positional argument
# NOTE(review): same expression as the assignment on the previous line; the
# second assignment is redundant.
inference_config = call_args[0][1]
# 1024 is the hidden_size set on mock_model.config above.
assert inference_config.hidden_size == 1024

@patch("megatron.bridge.inference.vlm.base.QwenVLInferenceWrapper")
def test_setup_inference_wrapper_qwen3(self, mock_wrapper_cls, mock_tokenizer):
mock_model = MagicMock()
# Create a simple object without module attribute to avoid infinite loop
class MockObject:
pass

mock_model = MockObject()
mock_model.config = MagicMock(spec=Qwen3VLModelProvider)
mock_model.config.language_transformer_config = MagicMock()
mock_model.config.language_transformer_config.hidden_size = 2048
mock_model.cuda = MagicMock(return_value=mock_model)
mock_model.to = MagicMock(return_value=mock_model)
mock_model.eval = MagicMock()

_wrapper = setup_inference_wrapper(mock_model, mock_tokenizer)

Expand All @@ -266,8 +295,15 @@ def test_setup_inference_wrapper_qwen3(self, mock_wrapper_cls, mock_tokenizer):
assert inference_config.hidden_size == 2048

# Checks that setup_inference_wrapper raises ValueError when the model's config
# is a bare MagicMock rather than one spec'd to a Qwen provider class.
def test_setup_inference_wrapper_invalid(self, mock_tokenizer):
# NOTE(review): this MagicMock binding is overwritten by the MockObject
# instance assigned to mock_model below, before mock_model is used.
mock_model = MagicMock()
# Create a simple object without module attribute to avoid infinite loop
# (a MagicMock auto-creates a fresh child mock for every attribute access, so
# hasattr(mock, "module") is always True at every nesting level).
class MockObject:
pass

mock_model = MockObject()
mock_model.config = MagicMock() # Not Qwen config
# cuda() and to() hand back the same object; presumably setup_inference_wrapper
# chains these calls on the model (not visible here; confirm against the caller).
mock_model.cuda = MagicMock(return_value=mock_model)
mock_model.to = MagicMock(return_value=mock_model)
mock_model.eval = MagicMock()

# The call must raise; pytest.raises fails the test if no ValueError is raised.
with pytest.raises(ValueError):
setup_inference_wrapper(mock_model, mock_tokenizer)
Expand Down
Loading