Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/compile/test_aot_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,37 @@ def test_partition_wrapper_applied_on_aot_load(
)


@create_new_process_for_each_test("spawn")
def test_standalone_compile_correctness():
    """Outputs must match regardless of VLLM_USE_STANDALONE_COMPILE."""
    import json

    from ..utils import compare_two_settings

    # Request full vLLM compilation so the standalone-compile path is
    # actually exercised in both runs.
    config_json = json.dumps({"mode": CompilationMode.VLLM_COMPILE})

    shared_cli_args = [
        "--dtype",
        "float16",
        "--max-model-len",
        "256",
        "--compilation_config",
        config_json,
    ]

    # Identical model and CLI args on both sides; only the env var differs,
    # so any output divergence is attributable to the standalone-compile flag.
    compare_two_settings(
        "facebook/opt-125m",
        shared_cli_args,
        shared_cli_args,
        env1={"VLLM_USE_STANDALONE_COMPILE": "1"},
        env2={"VLLM_USE_STANDALONE_COMPILE": "0"},
    )


@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
@create_new_process_for_each_test("spawn")
def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
Expand Down
17 changes: 17 additions & 0 deletions vllm/compilation/compiler_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,23 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
)
stack.enter_context(_patch_constrain_to_fx_strides())

# Clear the tracing context before calling compile_fx.
# vLLM calls compile_fx from within a PiecewiseCompileInterpreter
# that runs under Dynamo's tracing context. The tracing context
# has a FakeTensorMode from Dynamo, but the example inputs for
# this subgraph have fake tensors from a different FakeTensorMode.
# compile_fx's _compile_fx_main calls detect_fake_mode() which
# asserts all FakeTensorModes match, causing a crash.
# Clearing the tracing context lets compile_fx create its own.
saved_tracing_context = torch._guards.TracingContext.try_get()
if saved_tracing_context is not None:
torch._guards._TLS.tracing_context = None

def _restore_tracing_context():
torch._guards._TLS.tracing_context = saved_tracing_context

stack.callback(_restore_tracing_context)
Comment on lines +643 to +650
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This fix relies on torch._guards._TLS.tracing_context, which is a private, undocumented PyTorch API. This makes the code fragile and likely to break in future PyTorch versions.

To improve maintainability, please add a comment to this block explaining that this uses a private PyTorch API and could break in the future. It would also be beneficial to file an issue with PyTorch to request a public API for this functionality and link it in the comment.

References
  1. Avoid using private, undocumented APIs from third-party libraries, as they are not part of the public contract and can change or be removed without notice, leading to future breakage.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we're going to deprecate and delete this path (VLLM_USE_STANDALONE_COMPILE=0) so I'm not worried about it


compiled_graph = compile_fx(
graph,
example_inputs,
Expand Down
Loading