Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions tests/compile/test_aot_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
StandaloneCompiledArtifacts,
VllmSerializableFunction,
)
from vllm.compilation.compiler_interface import (
_patch_standalone_compile_fake_tensor_mode,
)
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
Expand Down Expand Up @@ -886,3 +889,27 @@ def test_aot_counters_on_save_and_load(
),
):
CompiledMod(vllm_config=vllm_config)(*args)


def test_patch_standalone_compile_fake_tensor_mode_uses_function_globals():
fake_mode = object()

def standalone_compile_like():
return FakeTensorMode() # noqa: F821

original_fake_tensor_mode = standalone_compile_like.__globals__.get(
"FakeTensorMode")

with _patch_standalone_compile_fake_tensor_mode(
standalone_compile_like,
fake_mode,
):
assert standalone_compile_like() is fake_mode

if original_fake_tensor_mode is None:
assert "FakeTensorMode" not in standalone_compile_like.__globals__
else:
assert (
standalone_compile_like.__globals__["FakeTensorMode"]
is original_fake_tensor_mode
)
Comment on lines +894 to +915
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current test only covers the case where FakeTensorMode is not present in the function's globals. The restoration logic for when FakeTensorMode already exists is not tested, leaving a gap in test coverage for the new helper function. The else branch of the if statement is currently unreachable.

I suggest restructuring the test to explicitly cover both scenarios: when FakeTensorMode is absent and when it is present in the globals. This will ensure the patching and restoration logic of patch.dict is fully verified for this use case.

def test_patch_standalone_compile_fake_tensor_mode_uses_function_globals():
    fake_mode = object()

    def standalone_compile_like():
        return FakeTensorMode()  # noqa: F821

    # Case 1: FakeTensorMode is not in globals.
    assert "FakeTensorMode" not in standalone_compile_like.__globals__
    with _patch_standalone_compile_fake_tensor_mode(
        standalone_compile_like,
        fake_mode,
    ):
        assert standalone_compile_like() is fake_mode
    assert "FakeTensorMode" not in standalone_compile_like.__globals__

    # Case 2: FakeTensorMode is in globals.
    original_mode = object()
    standalone_compile_like.__globals__["FakeTensorMode"] = original_mode
    try:
        with _patch_standalone_compile_fake_tensor_mode(
            standalone_compile_like,
            fake_mode,
        ):
            assert standalone_compile_like() is fake_mode
        assert standalone_compile_like.__globals__["FakeTensorMode"] is original_mode
    finally:
        del standalone_compile_like.__globals__["FakeTensorMode"]

43 changes: 33 additions & 10 deletions vllm/compilation/compiler_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import contextlib
import copy
import os
from collections.abc import Mapping
from collections.abc import Callable
from contextlib import ExitStack
from typing import Any, Literal
Expand All @@ -23,6 +24,35 @@
logger = init_logger(__name__)


def _patch_standalone_compile_fake_tensor_mode(
standalone_compile: Callable[..., Any],
fake_tensor_mode: Any,
) -> Any:
"""Patch the FakeTensorMode lookup used by torch standalone_compile.

`torch._inductor.standalone_compile` can be imported either as the wrapper
function exposed from `torch._inductor.__init__` or as the function defined
inside `torch._inductor.standalone_compile`. On some Python/PyTorch
combinations, patching `torch._inductor.standalone_compile.FakeTensorMode`
resolves against the wrapper function instead of the module, which raises:

AttributeError: <function standalone_compile ...> does not have the
attribute 'FakeTensorMode'

Patch the function's globals instead, since that is where the runtime name
lookup for `FakeTensorMode(...)` actually happens.
"""

globals_dict = getattr(standalone_compile, "__globals__", None)
if not isinstance(globals_dict, Mapping):
return contextlib.nullcontext()

return patch.dict(
globals_dict,
{"FakeTensorMode": lambda *a, **kw: fake_tensor_mode},
)


class CompilerInterface:
"""
The interface for a compiler that can be used by vLLM.
Expand Down Expand Up @@ -373,16 +403,9 @@ def compile(
break

if input_fake_mode is not None:
# Use patch.object on the actual module from sys.modules
# because in Python <=3.10 the string-based patch() resolves
# torch._inductor.standalone_compile to the wrapper function
# (defined in __init__.py) instead of the module.
import sys

fake_mode_ctx: Any = patch.object(
sys.modules["torch._inductor.standalone_compile"],
"FakeTensorMode",
lambda *a, **kw: input_fake_mode,
fake_mode_ctx: Any = _patch_standalone_compile_fake_tensor_mode(
standalone_compile,
input_fake_mode,
)
else:
fake_mode_ctx = contextlib.nullcontext()
Expand Down
Loading