Merged
tests/models/multimodal/generation/test_common.py (1 addition, 0 deletions)
@@ -777,6 +777,7 @@
max_model_len=8192,
max_num_seqs=2,
auto_cls=AutoModelForCausalLM,
patch_hf_runner=model_utils.paddleocr_vl_patch_hf_runner,
image_size_factors=[(0.25,)],
marks=[
pytest.mark.skipif(
tests/models/multimodal/generation/vlm_utils/model_utils.py (25 additions, 0 deletions)
@@ -1149,6 +1149,31 @@ def processor(*args, text="", images=None, videos=None, **kwargs):
return hf_model


def paddleocr_vl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches the HfRunner to fix create_causal_mask API mismatch.

The PaddleOCR-VL HF model passes `inputs_embeds` to create_causal_mask,
but transformers renamed this parameter to `input_embeds`.
"""
import sys

model_module = sys.modules.get(type(hf_model.model.model).__module__)
if model_module is None:
return hf_model

original_create_causal_mask = getattr(model_module, "create_causal_mask", None)
if original_create_causal_mask is None:
return hf_model

def patched_create_causal_mask(*args, **kwargs):
if "inputs_embeds" in kwargs:
kwargs["input_embeds"] = kwargs.pop("inputs_embeds")
return original_create_causal_mask(*args, **kwargs)

model_module.create_causal_mask = patched_create_causal_mask # type: ignore[attr-defined]
return hf_model
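The kwarg-renaming shim above follows a general pattern: wrap the original callable and translate the old keyword name to the new one before delegating. A minimal standalone sketch of the same idea, using a hypothetical stand-in function rather than the real transformers helper:

```python
def create_causal_mask(*, input_embeds=None, **kwargs):
    # Stand-in for the renamed helper: it only accepts `input_embeds`.
    return f"mask for {input_embeds}"

original = create_causal_mask

def patched(*args, **kwargs):
    # Translate the old kwarg name (`inputs_embeds`) to the new one
    # (`input_embeds`) before delegating to the original function.
    if "inputs_embeds" in kwargs:
        kwargs["input_embeds"] = kwargs.pop("inputs_embeds")
    return original(*args, **kwargs)

# A caller still using the old name now works through the shim.
result = patched(inputs_embeds="x")
```

Callers that already use the new name pass through unchanged, so the shim is transparent in both directions.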
Contributor comment on lines +1152 to +1174 (severity: high):

This monkey-patch modifies a module in sys.modules, which is global state. It can break test isolation and cause side effects in other tests running in the same process. For instance, a different test might rely on the original behavior of create_causal_mask and would then fail in an unexpected way.

To ensure test reliability, this patch should be reverted after the test finishes. The standard way to achieve this is by using a context manager or a pytest fixture that handles the setup and teardown of the patch.

A possible approach would be to refactor this into a context manager that can be used within the test execution logic:

from contextlib import contextmanager

@contextmanager
def patched_create_causal_mask_for_paddleocr_vl(hf_model: HfRunner):
    import sys
    model_module = sys.modules.get(type(hf_model.model.model).__module__)
    
    original_create_causal_mask = getattr(model_module, "create_causal_mask", None)
    if original_create_causal_mask is None:
        yield
        return

    def patched_create_causal_mask(*args, **kwargs):
        if "inputs_embeds" in kwargs:
            kwargs["input_embeds"] = kwargs.pop("inputs_embeds")
        return original_create_causal_mask(*args, **kwargs)

    model_module.create_causal_mask = patched_create_causal_mask
    try:
        yield
    finally:
        model_module.create_causal_mask = original_create_causal_mask

This would require changes in the test runner to use the context manager, but it would make the test suite more robust.
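Another route to automatic teardown is pytest's `monkeypatch` fixture, whose `setattr` records the original attribute and restores it when the test ends. A hedged, self-contained sketch of that shape (the stub module and `MiniMonkeyPatch` class below are illustrative stand-ins for the real modeling module and fixture, so this runs without pytest or transformers):

```python
import sys
import types

# Stub module standing in for the PaddleOCR-VL modeling module (illustrative).
mod = types.ModuleType("fake_paddleocr_vl")
mod.create_causal_mask = lambda *, input_embeds=None: ("mask", input_embeds)
sys.modules["fake_paddleocr_vl"] = mod


class MiniMonkeyPatch:
    """Minimal stand-in for pytest's monkeypatch fixture (setattr + undo)."""

    def __init__(self):
        self._undo = []

    def setattr(self, target, name, value):
        # Record the original value so undo() can restore it.
        self._undo.append((target, name, getattr(target, name)))
        setattr(target, name, value)

    def undo(self):
        for target, name, old in reversed(self._undo):
            setattr(target, name, old)


def apply_kwarg_rename_patch(monkeypatch, module):
    original = module.create_causal_mask

    def patched(*args, **kwargs):
        if "inputs_embeds" in kwargs:
            kwargs["input_embeds"] = kwargs.pop("inputs_embeds")
        return original(*args, **kwargs)

    # With the real pytest fixture, this is reverted at test teardown.
    monkeypatch.setattr(module, "create_causal_mask", patched)


mp = MiniMonkeyPatch()
apply_kwarg_rename_patch(mp, mod)
patched_result = mod.create_causal_mask(inputs_embeds="x")  # old kwarg works
mp.undo()
restored = mod.create_causal_mask  # original function is back after teardown
```

In a real test, `monkeypatch` would come in as a fixture argument and the `undo()` call would be implicit, so no manual cleanup code is needed.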

Collaborator (Author) reply:
Global state is technically a valid concern but not practically relevant here: the patched module is test-specific, and the patch only renames the inputs_embeds kwarg.



def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner for Qwen2.5-Omni."""
thinker = hf_model.model.thinker