diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py index 4cfdc1b2e7f6..4772ef4c9664 100644 --- a/tests/compile/test_aot_compile.py +++ b/tests/compile/test_aot_compile.py @@ -14,6 +14,7 @@ import torch import vllm.model_executor.layers.activation +from vllm.compilation.backends import VllmBackend from vllm.compilation.caching import ( StandaloneCompiledArtifacts, VllmSerializableFunction, @@ -721,3 +722,44 @@ def test_deduplication(self): ("mod3", "shape3"), ]: assert cache.get(submod, shape) == shared_data + + def test_functorch_config(self): + vllm_config = make_vllm_config() + example_inputs = (torch.randn(10, 10),) + + def add_1(x: torch.Tensor): + return x + 1 + + gm = torch._dynamo.functional_export.dynamo_graph_capture_for_export(add_1)( + *example_inputs + ) + + gm.graph._codegen = torch.fx.graph.CodeGen() + gm._dynamo_bytecode_flatten = None + gm._dynamo_bytecode_unflatten = None + + with ( + torch._functorch.config.patch(bundled_autograd_cache=False), + set_current_vllm_config(vllm_config), + ): + with torch._functorch.config.patch(bundled_autograd_cache=True): + fn = VllmSerializableFunction(gm, example_inputs, "", add_1) + + payload = VllmSerializableFunction.serialize_compile_artifacts(fn) + + config = None + + def backend(*args, **kwargs) -> VllmSerializableFunction: + nonlocal config + # bundled_autograd_cache should be True even if the compiler backend # runs with bundled_autograd_cache=False in ambient context. 
+ config = torch._functorch.config.save_config_portable() + return fn + + loaded_fn = VllmSerializableFunction.deserialize_compile_artifacts(payload) + with patch.object(VllmBackend, "__call__", backend): + loaded_fn(*example_inputs) + + assert isinstance(config, dict) + assert "bundled_autograd_cache" in config + assert config["bundled_autograd_cache"] is True diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py index 7f3a844a5905..3eda948b693f 100644 --- a/vllm/compilation/caching.py +++ b/vllm/compilation/caching.py @@ -178,6 +178,7 @@ def __init__( is_encoder: bool = False, vllm_backend: Any | None = None, sym_tensor_indices: list[int] | None = None, + aot_autograd_config: dict[str, Any] | None = None, ) -> None: assert isinstance(graph_module, torch.fx.GraphModule) self.graph_module = graph_module @@ -188,6 +189,13 @@ def __init__( self.shape_env = None self.vllm_backend = vllm_backend self.sym_tensor_indices = sym_tensor_indices + + import torch._functorch.config as functorch_config + + self.aot_autograd_config = ( + aot_autograd_config or functorch_config.save_config_portable() + ) + sym_input = next( (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None ) @@ -286,6 +294,12 @@ def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction sym_shape_indices_map = state.pop("sym_shape_indices_map", {}) returns_tuple_map = state.pop("returns_tuple_map", {}) + saved_aot_autograd_config = state["aot_autograd_config"] + if saved_aot_autograd_config is not None: + functorch_ctx = torch._functorch.config.patch(saved_aot_autograd_config) + else: + functorch_ctx = contextlib.nullcontext() + if envs.VLLM_USE_MEGA_AOT_ARTIFACT: assert standalone_compile_artifacts is not None submod_names = standalone_compile_artifacts.submodule_names() @@ -299,13 +313,14 @@ def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction num_submods, ) - fn = reconstruct_serializable_fn_from_mega_artifact( - 
state=state, - standalone_compile_artifacts=standalone_compile_artifacts, - vllm_config=get_current_vllm_config(), - sym_shape_indices_map=sym_shape_indices_map, - returns_tuple_map=returns_tuple_map, - ) + with functorch_ctx: + fn = reconstruct_serializable_fn_from_mega_artifact( + state=state, + standalone_compile_artifacts=standalone_compile_artifacts, + vllm_config=get_current_vllm_config(), + sym_shape_indices_map=sym_shape_indices_map, + returns_tuple_map=returns_tuple_map, + ) logger.info( "reconstructed serializable fn from standalone compile artifacts" @@ -328,7 +343,7 @@ def optimized_call(*example_inputs: Any) -> Any: vllm_backend: VllmBackend = VllmBackend( vllm_config, state["prefix"], is_encoder ) - with tracing(TracingContext(fake_mode)): + with tracing(TracingContext(fake_mode)), functorch_ctx: fn.optimized_call = vllm_backend( state["graph_module"], compile_inputs ).optimized_call