From ed6386aefdd0ffad9c8d2913899f91662a18e77c Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Fri, 13 Mar 2026 13:42:57 -0700
Subject: [PATCH] [BugFix] fix VLLM_USE_STANDALONE_COMPILE=0

I broke this in one of the refactors. This fixes it by clearing the
tracing context before calling compile_fx (so that detect_fake_mode()
no longer sees mismatched FakeTensorModes) and adds a test that compares
outputs with VLLM_USE_STANDALONE_COMPILE=1 and VLLM_USE_STANDALONE_COMPILE=0.

Signed-off-by: Richard Zou
---
 tests/compile/test_aot_compile.py      | 31 ++++++++++++++++++++++++++
 vllm/compilation/compiler_interface.py | 17 ++++++++++++++
 2 files changed, 48 insertions(+)

diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py
index 8a5191ed226c..c3a065c56142 100644
--- a/tests/compile/test_aot_compile.py
+++ b/tests/compile/test_aot_compile.py
@@ -441,6 +441,37 @@ def test_partition_wrapper_applied_on_aot_load(
             )
 
 
+@create_new_process_for_each_test("spawn")
+def test_standalone_compile_correctness():
+    """Outputs must match regardless of VLLM_USE_STANDALONE_COMPILE."""
+    import json
+
+    from ..utils import compare_two_settings
+
+    compilation_config = json.dumps(
+        {
+            "mode": CompilationMode.VLLM_COMPILE,
+        }
+    )
+
+    common_args = [
+        "--dtype",
+        "float16",
+        "--max-model-len",
+        "256",
+        "--compilation_config",
+        compilation_config,
+    ]
+
+    compare_two_settings(
+        "facebook/opt-125m",
+        common_args,
+        common_args,
+        env1={"VLLM_USE_STANDALONE_COMPILE": "1"},
+        env2={"VLLM_USE_STANDALONE_COMPILE": "0"},
+    )
+
+
 @pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
 @create_new_process_for_each_test("spawn")
 def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py
index ac63143b0051..bddacfbbc295 100644
--- a/vllm/compilation/compiler_interface.py
+++ b/vllm/compilation/compiler_interface.py
@@ -632,6 +632,23 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
             )
             stack.enter_context(_patch_constrain_to_fx_strides())
 
+            # Clear the tracing context before calling compile_fx.
+            # vLLM calls compile_fx from within a PiecewiseCompileInterpreter
+            # that runs under Dynamo's tracing context. The tracing context
+            # has a FakeTensorMode from Dynamo, but the example inputs for
+            # this subgraph have fake tensors from a different FakeTensorMode.
+            # compile_fx's _compile_fx_main calls detect_fake_mode() which
+            # asserts all FakeTensorModes match, causing a crash.
+            # Clearing the tracing context lets compile_fx create its own.
+            saved_tracing_context = torch._guards.TracingContext.try_get()
+            if saved_tracing_context is not None:
+                torch._guards._TLS.tracing_context = None
+
+                def _restore_tracing_context():
+                    torch._guards._TLS.tracing_context = saved_tracing_context
+
+                stack.callback(_restore_tracing_context)
+
             compiled_graph = compile_fx(
                 graph,
                 example_inputs,