From f67cba24b2e1f816057de8dd1d21742abbe04a4f Mon Sep 17 00:00:00 2001
From: CMLKevin
Date: Mon, 23 Mar 2026 16:14:41 +0800
Subject: [PATCH] fix(compilation): patch standalone FakeTensorMode via function globals

Patch `FakeTensorMode` in the globals of the callable that is passed in.
When that callable is the wrapper exposed from `torch._inductor.__init__`,
also patch the `torch._inductor.standalone_compile` submodule, because the
wrapper's own globals are never consulted for the `FakeTensorMode` lookup.
---
Notes:
    Line breaks and context-line indentation were reconstructed from a
    whitespace-collapsed copy of this patch; confirm them against the tree
    or apply with `git apply --ignore-whitespace`. The post-image blob id
    recorded for compiler_interface.py predates the helper rewrite below.

 tests/compile/test_aot_compile.py      | 27 +++++++++
 vllm/compilation/compiler_interface.py | 73 +++++++++++++++++++++-----
 2 files changed, 89 insertions(+), 11 deletions(-)

diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py
index 8a5191ed226c..2ce1ab68c145 100644
--- a/tests/compile/test_aot_compile.py
+++ b/tests/compile/test_aot_compile.py
@@ -21,6 +21,9 @@
     StandaloneCompiledArtifacts,
     VllmSerializableFunction,
 )
+from vllm.compilation.compiler_interface import (
+    _patch_standalone_compile_fake_tensor_mode,
+)
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (
@@ -886,3 +889,27 @@ def test_aot_counters_on_save_and_load(
         ),
     ):
         CompiledMod(vllm_config=vllm_config)(*args)
+
+
+def test_patch_standalone_compile_fake_tensor_mode_uses_function_globals():
+    fake_mode = object()
+
+    def standalone_compile_like():
+        return FakeTensorMode()  # noqa: F821
+
+    original_fake_tensor_mode = standalone_compile_like.__globals__.get(
+        "FakeTensorMode")
+
+    with _patch_standalone_compile_fake_tensor_mode(
+        standalone_compile_like,
+        fake_mode,
+    ):
+        assert standalone_compile_like() is fake_mode
+
+    if original_fake_tensor_mode is None:
+        assert "FakeTensorMode" not in standalone_compile_like.__globals__
+    else:
+        assert (
+            standalone_compile_like.__globals__["FakeTensorMode"]
+            is original_fake_tensor_mode
+        )
diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py
index ac63143b0051..dcd8d33e4870 100644
--- a/vllm/compilation/compiler_interface.py
+++ b/vllm/compilation/compiler_interface.py
@@ -3,6 +3,6 @@
 import contextlib
 import copy
 import os
-from collections.abc import Callable
+from collections.abc import Callable, Iterator, MutableMapping
 from contextlib import ExitStack
 from typing import Any, Literal
@@ -23,6 +23,64 @@
 logger = init_logger(__name__)
 
 
+@contextlib.contextmanager
+def _patch_standalone_compile_fake_tensor_mode(
+    standalone_compile: Callable[..., Any],
+    fake_tensor_mode: Any,
+) -> Iterator[None]:
+    """Make `FakeTensorMode(...)` inside standalone_compile return a fixed mode.
+
+    `torch._inductor.standalone_compile` names two objects: the wrapper
+    function exposed from `torch._inductor.__init__` and the submodule that
+    defines the real implementation. A string-based
+    `patch("torch._inductor.standalone_compile.FakeTensorMode", ...)` can
+    resolve to the wrapper function instead of the module, which raises:
+
+        AttributeError: <function standalone_compile at 0x...> does not
+        have the attribute 'FakeTensorMode'
+
+    The name `FakeTensorMode` is looked up in the globals of the function
+    that actually executes the call, so those globals are patched directly.
+    When the wrapper is passed in, its globals belong to the
+    `torch._inductor` package rather than to the implementation submodule,
+    so the submodule's namespace is patched as well.
+
+    Args:
+        standalone_compile: The callable that will be invoked while the
+            patch is active.
+        fake_tensor_mode: The object every `FakeTensorMode(...)` call should
+            return.
+
+    Yields:
+        None. The original bindings are restored on exit.
+    """
+    replacement = {"FakeTensorMode": lambda *a, **kw: fake_tensor_mode}
+    namespaces: list[MutableMapping[str, Any]] = []
+
+    fn_globals = getattr(standalone_compile, "__globals__", None)
+    if isinstance(fn_globals, MutableMapping):
+        namespaces.append(fn_globals)
+
+    if getattr(standalone_compile, "__module__", None) == "torch._inductor":
+        import importlib
+
+        # import_module returns the sys.modules entry; attribute access on
+        # the package would hand back the wrapper function instead.
+        impl_module = importlib.import_module("torch._inductor.standalone_compile")
+        namespaces.append(vars(impl_module))
+
+    if not namespaces:
+        logger.warning(
+            "Cannot patch FakeTensorMode for %r: no function globals found.",
+            standalone_compile,
+        )
+
+    with ExitStack() as stack:
+        for namespace in namespaces:
+            stack.enter_context(patch.dict(namespace, replacement))
+        yield
+
+
 class CompilerInterface:
     """
     The interface for a compiler that can be used by vLLM.
@@ -373,16 +431,9 @@ def compile(
                     break
 
         if input_fake_mode is not None:
-            # Use patch.object on the actual module from sys.modules
-            # because in Python <=3.10 the string-based patch() resolves
-            # torch._inductor.standalone_compile to the wrapper function
-            # (defined in __init__.py) instead of the module.
-            import sys
-
-            fake_mode_ctx: Any = patch.object(
-                sys.modules["torch._inductor.standalone_compile"],
-                "FakeTensorMode",
-                lambda *a, **kw: input_fake_mode,
+            fake_mode_ctx: Any = _patch_standalone_compile_fake_tensor_mode(
+                standalone_compile,
+                input_fake_mode,
             )
         else:
             fake_mode_ctx = contextlib.nullcontext()