13 changes: 13 additions & 0 deletions src/megatron/bridge/inference/vlm/base.py
@@ -112,6 +112,17 @@ def setup_model_and_tokenizer(
return inference_wrapped_model, processor


def _expose_decoder_from_language_model(model):
"""Recursively get language_model from model and expose decoder, handling wrapped modules."""
current = model
while hasattr(current, "module"):
current = current.module

if hasattr(current, "language_model"):
language_model = current.language_model
current.decoder = language_model.decoder
Comment on lines +115 to +123

Contributor

⚠️ Potential issue | 🟡 Minor

Guard against cyclic .module and add type hints/docstring.
The unwrap loop can hang if .module ever points to itself (or cycles). Also, this new helper should follow the typing + Google-docstring rules.

🔧 Proposed fix
-def _expose_decoder_from_language_model(model):
-    """Recursively get language_model from model and expose decoder, handling wrapped modules."""
-    current = model
-    while hasattr(current, "module"):
-        current = current.module
+def _expose_decoder_from_language_model(model: torch.nn.Module) -> None:
+    """Expose the decoder on the base model after unwrapping `.module` layers.
+
+    Args:
+        model: The potentially wrapped model instance.
+
+    Returns:
+        None.
+    """
+    current = model
+    while hasattr(current, "module"):
+        next_module = current.module
+        if next_module is None or next_module is current:
+            break
+        current = next_module

As per coding guidelines: Use Google style docstrings for classes and functions; Use type hints for function arguments and return types.

🤖 Prompt for AI Agents
In `@src/megatron/bridge/inference/vlm/base.py` around lines 115-123, the helper
_expose_decoder_from_language_model currently unwraps wrapped modules by
following .module and can loop forever if .module cycles. Update its signature
to include type hints (e.g., model: torch.nn.Module) and a return type of None,
add a Google-style docstring describing args and return, and modify the unwrap
loop to detect cycles by tracking visited objects (e.g., using a set of ids of
`current`), breaking or returning if a cycle is detected. After unwrapping,
safely access `language_model` and assign `current.decoder =
language_model.decoder` only if it is present.
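
For reference, a minimal sketch of the id-tracking variant the prompt describes (a hypothetical rewrite of the helper, not the committed code):

import torch


def _expose_decoder_from_language_model(model: torch.nn.Module) -> None:
    """Expose the language model's decoder on the unwrapped base model.

    Args:
        model: The potentially wrapped (e.g., DDP or Float16Module) model instance.
    """
    seen: set[int] = set()
    current = model
    # Follow .module wrappers, refusing to revisit an object already seen,
    # so a self-referential or cyclic .module chain cannot hang the loop.
    while hasattr(current, "module") and id(current) not in seen:
        seen.add(id(current))
        current = current.module

    if hasattr(current, "language_model"):
        current.decoder = current.language_model.decoder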



def setup_inference_wrapper(
model,
tokenizer,
@@ -131,6 +142,8 @@ def setup_inference_wrapper(
wrapper_cls = QwenVLInferenceWrapper
if isinstance(config, Qwen25VLModelProvider):
hidden_size = config.hidden_size
# Expose decoder for MCore Inference Engine compatibility (used by get_mamba_inference_state_config_from_model)
_expose_decoder_from_language_model(mcore_model)
else:
hidden_size = config.language_transformer_config.hidden_size
else:
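For context on why the helper exists at all: distributed and precision wrappers nest the actual VLM under one or more .module attributes, while downstream MCore inference utilities such as get_mamba_inference_state_config_from_model look up decoder on the model they are handed. A toy illustration of that shape, assuming the _expose_decoder_from_language_model helper from the diff above is in scope (all class names below are stand-ins, not real Megatron types):

class Decoder:
    """Stand-in for the language model's transformer decoder."""


class LanguageModel:
    def __init__(self) -> None:
        self.decoder = Decoder()


class VLModel:
    """Stand-in for Qwen25VLModel: has a language_model but no top-level decoder."""

    def __init__(self) -> None:
        self.language_model = LanguageModel()


class Wrapper:
    """Stand-in for a DDP / Float16Module-style wrapper."""

    def __init__(self, module) -> None:
        self.module = module


wrapped = Wrapper(Wrapper(VLModel()))
_expose_decoder_from_language_model(wrapped)

# The innermost model now exposes `decoder` at the top level, which is the
# attribute the inference engine expects to find.
base = wrapped.module.module
assert base.decoder is base.language_model.decoder
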
3 changes: 0 additions & 3 deletions src/megatron/bridge/models/qwen_vl/modeling_qwen25_vl.py
@@ -111,9 +111,6 @@ def __init__(
self.share_embeddings_and_output_weights = config.share_embeddings_and_output_weights
self.shared_embedding_or_output_weight = self.language_model.shared_embedding_or_output_weight

# Expose decoder for MCore Inference Engine compatibility (used by get_mamba_inference_state_config_from_model)
self.decoder = self.language_model.decoder

# Bind methods from HF's Qwen2_5_VLModel to this instance
# get_placeholder_mask is only available in transformers 4.55+
if is_transformers_min_version("4.55.0"):
46 changes: 41 additions & 5 deletions tests/unit_tests/inference/vlm/test_base.py
@@ -236,25 +236,54 @@ class TestSetupInferenceWrapper:

@patch("megatron.bridge.inference.vlm.base.QwenVLInferenceWrapper")
def test_setup_inference_wrapper_qwen25(self, mock_wrapper_cls, mock_tokenizer):
mock_model = MagicMock()
"""Test Qwen25 setup with module.language_model.decoder structure."""

# Create mock objects with nested structure
class MockObject:
pass

mock_decoder = MagicMock()

# Build the nested structure: model.module.language_model.decoder
mock_language_model = MockObject()
mock_language_model.decoder = mock_decoder

mock_module = MockObject()
mock_module.language_model = mock_language_model

mock_model = MockObject()
mock_model.module = mock_module
mock_model.config = MagicMock(spec=Qwen25VLModelProvider)
mock_model.config.hidden_size = 1024
mock_model.cuda = MagicMock(return_value=mock_model)
mock_model.to = MagicMock(return_value=mock_model)
mock_model.eval = MagicMock()

_wrapper = setup_inference_wrapper(mock_model, mock_tokenizer)
Contributor

⚠️ Potential issue | 🟡 Minor

Remove unused _wrapper assignments to satisfy lint.
_wrapper is assigned but never used, and F841 will fail lint. Drop the assignment or use the value.

🧹 Proposed fix
-        _wrapper = setup_inference_wrapper(mock_model, mock_tokenizer)
+        setup_inference_wrapper(mock_model, mock_tokenizer)
-        _wrapper = setup_inference_wrapper(mock_model, mock_tokenizer)
+        setup_inference_wrapper(mock_model, mock_tokenizer)

Also applies to: 288-288

🧰 Tools
🪛 Flake8 (7.3.0)

[error] 262-262: local variable '_wrapper' is assigned to but never used

(F841)

🤖 Prompt for AI Agents
In `@tests/unit_tests/inference/vlm/test_base.py` at line 262, the test assigns
`_wrapper = setup_inference_wrapper(mock_model, mock_tokenizer)` (and the same
at another spot) but never uses `_wrapper`, which fails lint with F841. Fix by
removing the unused assignment and calling `setup_inference_wrapper(mock_model,
mock_tokenizer)` without assigning, or, if the return value is needed later,
assigning it to a name that is actually used (e.g., `wrapper`). Update both
lines where `_wrapper` is set accordingly.


# Verify decoder was exposed at module level
assert hasattr(mock_module, "decoder")
assert mock_module.decoder is mock_decoder

mock_wrapper_cls.assert_called_once()
# Check InferenceWrapperConfig was created with correct hidden_size
# Args are positional: (model, InferenceWrapperConfig)
call_args = mock_wrapper_cls.call_args
inference_config = call_args[0][1] # Second positional argument
inference_config = call_args[0][1]
assert inference_config.hidden_size == 1024

@patch("megatron.bridge.inference.vlm.base.QwenVLInferenceWrapper")
def test_setup_inference_wrapper_qwen3(self, mock_wrapper_cls, mock_tokenizer):
mock_model = MagicMock()
# Create a simple object without module attribute to avoid infinite loop
class MockObject:
pass

mock_model = MockObject()
mock_model.config = MagicMock(spec=Qwen3VLModelProvider)
mock_model.config.language_transformer_config = MagicMock()
mock_model.config.language_transformer_config.hidden_size = 2048
mock_model.cuda = MagicMock(return_value=mock_model)
mock_model.to = MagicMock(return_value=mock_model)
mock_model.eval = MagicMock()

_wrapper = setup_inference_wrapper(mock_model, mock_tokenizer)

@@ -266,8 +295,15 @@ def test_setup_inference_wrapper_qwen3(self, mock_wrapper_cls, mock_tokenizer):
assert inference_config.hidden_size == 2048

def test_setup_inference_wrapper_invalid(self, mock_tokenizer):
mock_model = MagicMock()
# Create a simple object without module attribute to avoid infinite loop
class MockObject:
pass

mock_model = MockObject()
mock_model.config = MagicMock() # Not Qwen config
mock_model.cuda = MagicMock(return_value=mock_model)
mock_model.to = MagicMock(return_value=mock_model)
mock_model.eval = MagicMock()

with pytest.raises(ValueError):
setup_inference_wrapper(mock_model, mock_tokenizer)
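
A note on why these tests build a bare MockObject instead of MagicMock: MagicMock synthesizes attributes on access, so hasattr(mock, "module") is always true and every .module access yields a child mock that itself has a .module. The unwrap loop in _expose_decoder_from_language_model would therefore never terminate. A small demonstration (bounded here so it actually stops):

from unittest.mock import MagicMock

mock = MagicMock()
assert hasattr(mock, "module")  # MagicMock invents any attribute on demand

current = mock
for _ in range(5):  # bounded for demonstration; the real while-loop is not
    assert hasattr(current, "module")  # always true, at every level
    current = current.module  # each level is a new child mock

Note that even a visited-set cycle guard would not save the loop here: every child mock is a distinct object, so the chain grows forever without ever revisiting anything, which is why the tests avoid MagicMock for the model object entirely.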