
[Bugfix] Fix mock.patch resolution failure for standalone_compile.FakeTensorMode on Python <= 3.10 #37158

Merged
zou3519 merged 1 commit into vllm-project:main from dbari:dbariamis/fix-faketensors-bug on Mar 17, 2026

Conversation

dbari (Contributor) commented Mar 16, 2026

Purpose

This PR fixes an AttributeError: <function standalone_compile at 0x7b8cf83c40d0> does not have the attribute 'FakeTensorMode' crash introduced by #36093 when running with Python <= 3.10.

Stacktrace
[...]
(EngineCore pid=842) INFO 03-16 08:18:10 [backends.py:1048] Dynamo bytecode transform time: 5.43 s
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099] EngineCore failed to start.
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099] Traceback (most recent call last):
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/v1/engine/core.py", line 1073, in run_engine_core
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/tracing/otel.py", line 178, in sync_wrapper
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     return func(*args, **kwargs)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/v1/engine/core.py", line 839, in __init__
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     super().__init__(
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/v1/engine/core.py", line 122, in __init__
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     kv_cache_config = self._initialize_kv_caches(vllm_config)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/tracing/otel.py", line 178, in sync_wrapper
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     return func(*args, **kwargs)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/v1/engine/core.py", line 245, in _initialize_kv_caches
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     available_gpu_memory = self.model_executor.determine_available_memory()
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/v1/executor/abstract.py", line 136, in determine_available_memory
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     return self.collective_rpc("determine_available_memory")
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/v1/executor/uniproc_executor.py", line 78, in collective_rpc
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     result = run_method(self.driver_worker, method, args, kwargs)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/v1/serial_utils.py", line 459, in run_method
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     return func(*args, **kwargs)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     return func(*args, **kwargs)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/v1/worker/gpu_worker.py", line 388, in determine_available_memory
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     self.model_runner.profile_run()
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/v1/worker/gpu_model_runner.py", line 5527, in profile_run
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     hidden_states, last_hidden_states = self._dummy_run(
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     return func(*args, **kwargs)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/v1/worker/gpu_model_runner.py", line 5221, in _dummy_run
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     outputs = self.model(
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/compilation/cuda_graph.py", line 241, in __call__
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     return self.runnable(*args, **kwargs)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     return self._call_impl(*args, **kwargs)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     return forward_call(*args, **kwargs)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/model_executor/models/pixtral.py", line 381, in forward
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     hidden_states = self.language_model.model(
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/compilation/decorators.py", line 583, in __call__
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/compilation/wrapper.py", line 206, in aot_compile
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     return self._compiled_callable.aot_compile((args, kwargs))
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 832, in aot_compile
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     return aot_compile_fullgraph(
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/.venv/lib/python3.10/site-packages/torch/_dynamo/aot_compile.py", line 239, in aot_compile_fullgraph
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     compiled_fn = backend(
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/.venv/lib/python3.10/site-packages/torch/__init__.py", line 2509, in __call__
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     return self.compiler_fn(model_, inputs_, **self.kwargs)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/usr/lib/python3.10/contextlib.py", line 79, in inner
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     return func(*args, **kwds)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/compilation/backends.py", line 1114, in __call__
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     PiecewiseCompileInterpreter(
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/tracing/otel.py", line 178, in sync_wrapper
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     return func(*args, **kwargs)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/compilation/backends.py", line 640, in run
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     return super().run(*args)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/.venv/lib/python3.10/site-packages/torch/fx/interpreter.py", line 200, in run
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     self.env[node] = self.run_node(node)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/.venv/lib/python3.10/site-packages/torch/fx/interpreter.py", line 295, in run_node
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     return getattr(self, n.op)(n.target, args, kwargs)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/compilation/backends.py", line 667, in call_module
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     piecewise_backend = PiecewiseBackend(
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/compilation/piecewise_backend.py", line 189, in __init__
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     self.compile_all_ranges()
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/compilation/piecewise_backend.py", line 265, in compile_all_ranges
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     range_entry.runnable = self.vllm_backend.compiler_manager.compile(
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/tracing/otel.py", line 178, in sync_wrapper
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     return func(*args, **kwargs)
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/compilation/backends.py", line 346, in compile
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     compiled_graph, handle = self.compiler.compile(
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/scratch/code/vllm/vllm/compilation/compiler_interface.py", line 383, in compile
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     with pregrad_ctx, fake_mode_ctx, _patch_constrain_to_fx_strides():
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/usr/lib/python3.10/unittest/mock.py", line 1447, in __enter__
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     original, local = self.get_original()
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]   File "/usr/lib/python3.10/unittest/mock.py", line 1420, in get_original
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099]     raise AttributeError(
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099] AttributeError: <function standalone_compile at 0x7b8cf83c40d0> does not have the attribute 'FakeTensorMode'

The string-based patch("torch._inductor.standalone_compile.FakeTensorMode", ...) resolves torch._inductor.standalone_compile to the wrapper function defined in torch/_inductor/__init__.py rather than to the submodule of the same name, because on Python <= 3.10 mock's internal _importer walks the dotted path with getattr. On Python 3.11+, mock.patch switched to pkgutil.resolve_name (cpython#18544), which correctly resolves the name to the module. The fix fetches the module explicitly from sys.modules and uses patch.object instead of a string target.

Test Plan

Test by loading a model before and after the fix in a Python 3.10 environment.

Test Result

Before:

[...]
(EngineCore pid=842) INFO 03-16 08:18:10 [backends.py:1048] Dynamo bytecode transform time: 5.43 s
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099] EngineCore failed to start.
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099] Traceback (most recent call last):
[...]
(EngineCore pid=842) ERROR 03-16 08:18:11 [core.py:1099] AttributeError: <function standalone_compile at 0x7b8cf83c40d0> does not have the attribute 'FakeTensorMode'

After:

[...]
(EngineCore pid=211) INFO 03-16 08:12:00 [backends.py:1048] Dynamo bytecode transform time: 17.36 s
(EngineCore pid=211) INFO 03-16 08:12:01 [backends.py:284] Directly load the compiled graph(s) for compile range (1, 8192) from the cache, took 1.586 s
(EngineCore pid=211) INFO 03-16 08:12:01 [monitor.py:48] torch.compile took 19.36 s in total
[...]
(APIServer pid=102) INFO:     Started server process [102]
(APIServer pid=102) INFO:     Waiting for application startup.
(APIServer pid=102) INFO:     Application startup complete.


mergify Bot added the bug (Something isn't working) label on Mar 16, 2026
gemini-code-assist (Bot) left a comment

Code Review

This pull request effectively resolves a critical bug that caused mock.patch to fail when resolving standalone_compile.FakeTensorMode on Python versions 3.10 and older. The issue stemmed from mock.patch incorrectly interpreting torch._inductor.standalone_compile as a wrapper function rather than the intended module. The fix, which involves using patch.object with an explicit reference to the module from sys.modules, correctly targets the FakeTensorMode attribute within the standalone_compile module. This ensures compatibility and prevents crashes in older Python environments, significantly improving the robustness of the compilation process. The added comments provide clear context for the change, enhancing code clarity and maintainability.

dbari force-pushed the dbariamis/fix-faketensors-bug branch from 3a5c132 to c68b827 on March 16, 2026 08:47

mergify Bot commented Mar 16, 2026

Hi @dbari, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
dbari force-pushed the dbariamis/fix-faketensors-bug branch from c68b827 to 59a769f on March 16, 2026 09:34
dbari (Author) commented Mar 17, 2026

Hi, could I please get a review for this one? Thanks!

zou3519 (Collaborator) left a comment

lgtm, though I'm not sure how to prevent future regressions

zou3519 added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Mar 17, 2026
zou3519 enabled auto-merge (squash) on March 17, 2026 17:35
zou3519 merged commit 1204cf0 into vllm-project:main on Mar 17, 2026
56 checks passed
Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Mar 17, 2026
dbari (Author) commented Mar 18, 2026

> lgtm, though I'm not sure how to prevent future regressions

I would consider this kind of collision a code smell in PyTorch, but I couldn't find a linting rule that would prevent it. While it's not a problem in the 99% of cases that go through the standard Python import machinery, it can cause problems with manual traversal of modules or, as in this case, with patching.

In any case, for vLLM it seems to be a one-off problem that is now solved, and Python 3.10 reaches EOL in October 2026, so in a few months string-based patching will work on all supported Python versions.

Thanks for the review and merge!

dbari deleted the dbariamis/fix-faketensors-bug branch on March 18, 2026 09:16
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
duongck commented Mar 21, 2026

I still encounter this issue:

AttributeError: <function standalone_compile at 0x73f92c77eef0> does not have the attribute 'FakeTensorMode'
[rank0]:[W321 16:02:34.801848343 ProcessGroupNCCL.cpp:1553] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

I'm using Python 3.10.12, installing vLLM with pip install vllm --extra-index-url https://download.pytorch.org/whl/cu129.
I also tried installing with CUDA 13.0, but it still throws that error.
Can anyone help me fix it? Or did I miss something while installing?

dbari (Author) commented Mar 23, 2026

@duongck this had fixed the problem for me. Can you post a reproducer, i.e. which vLLM git hash you're using and what command line / model? You could check if there are other instances of patching the same function, which would need to be fixed.

zou3519 (Collaborator) commented Mar 23, 2026

I think the problem is that the vLLM v0.18.0 release didn't include this patch.

khluu pushed a commit that referenced this pull request Mar 25, 2026
(cherry picked from commit 1204cf0)
SouthWest7 pushed a commit to SouthWest7/vllm that referenced this pull request Mar 27, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026