fix: handle FakeTensorMode patching for PyTorch compatibility#37866
fix: handle FakeTensorMode patching for PyTorch compatibility#37866CMLKevin wants to merge 1 commit intovllm-project:mainfrom
Conversation
|
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs would not trigger full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀 |
There was a problem hiding this comment.
Code Review
This pull request introduces a more robust way to patch FakeTensorMode for PyTorch compatibility by modifying the function's __globals__ directly, which is a solid approach to handle inconsistencies across different environments. The change is well-implemented and includes a new unit test. My review includes a suggestion to enhance the new test to cover all intended scenarios, ensuring the restoration logic is also verified.
| def test_patch_standalone_compile_fake_tensor_mode_uses_function_globals(): | ||
| fake_mode = object() | ||
|
|
||
| def standalone_compile_like(): | ||
| return FakeTensorMode() # noqa: F821 | ||
|
|
||
| original_fake_tensor_mode = standalone_compile_like.__globals__.get( | ||
| "FakeTensorMode") | ||
|
|
||
| with _patch_standalone_compile_fake_tensor_mode( | ||
| standalone_compile_like, | ||
| fake_mode, | ||
| ): | ||
| assert standalone_compile_like() is fake_mode | ||
|
|
||
| if original_fake_tensor_mode is None: | ||
| assert "FakeTensorMode" not in standalone_compile_like.__globals__ | ||
| else: | ||
| assert ( | ||
| standalone_compile_like.__globals__["FakeTensorMode"] | ||
| is original_fake_tensor_mode | ||
| ) |
There was a problem hiding this comment.
The current test only covers the case where FakeTensorMode is not present in the function's globals. The restoration logic for when FakeTensorMode already exists is not tested, leaving a gap in test coverage for the new helper function. The else branch of the if statement is currently unreachable.
I suggest restructuring the test to explicitly cover both scenarios: when FakeTensorMode is absent and when it is present in the globals. This will ensure the patching and restoration logic of patch.dict is fully verified for this use case.
def test_patch_standalone_compile_fake_tensor_mode_uses_function_globals():
fake_mode = object()
def standalone_compile_like():
return FakeTensorMode() # noqa: F821
# Case 1: FakeTensorMode is not in globals.
assert "FakeTensorMode" not in standalone_compile_like.__globals__
with _patch_standalone_compile_fake_tensor_mode(
standalone_compile_like,
fake_mode,
):
assert standalone_compile_like() is fake_mode
assert "FakeTensorMode" not in standalone_compile_like.__globals__
# Case 2: FakeTensorMode is in globals.
original_mode = object()
standalone_compile_like.__globals__["FakeTensorMode"] = original_mode
try:
with _patch_standalone_compile_fake_tensor_mode(
standalone_compile_like,
fake_mode,
):
assert standalone_compile_like() is fake_mode
assert standalone_compile_like.__globals__["FakeTensorMode"] is original_mode
finally:
del standalone_compile_like.__globals__["FakeTensorMode"]
Fixes #37858
The current patch assumes
torch._inductor.standalone_compileresolves to the standalone_compile module and patchesFakeTensorModethere. In the failing PyTorch/Python combination from the issue, that lookup ends up on the wrapper function instead, which raisesAttributeError: <function standalone_compile ...> does not have the attribute 'FakeTensorMode'before compilation starts.This switches the override to patch the imported function's globals instead, which is where
FakeTensorMode(...)is actually resolved at runtime. I also added a small regression test covering the function-global patch path.Verification here was limited to
py_compileplus a direct smoke test of the helper because the repo test harness importstblib, which is missing in this environment.