Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions tests/compile/test_aot_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch

import vllm.model_executor.layers.activation
from vllm.compilation.backends import VllmBackend
from vllm.compilation.caching import (
StandaloneCompiledArtifacts,
VllmSerializableFunction,
Expand Down Expand Up @@ -721,3 +722,44 @@ def test_deduplication(self):
("mod3", "shape3"),
]:
assert cache.get(submod, shape) == shared_data

def test_functorch_config(self):
    """Round-trip serialization must restore the functorch config that was
    active when the artifacts were created, even if deserialization runs
    under a different ambient ``bundled_autograd_cache`` setting."""
    vllm_config = make_vllm_config()
    example_inputs = (torch.randn(10, 10),)

    def plus_one(t: torch.Tensor):
        return t + 1

    # Capture a graph module for the trivial function via dynamo export.
    capture = torch._dynamo.functional_export.dynamo_graph_capture_for_export
    gm = capture(plus_one)(*example_inputs)

    # Strip dynamo-specific codegen/bytecode hooks so the module serializes.
    gm.graph._codegen = torch.fx.graph.CodeGen()
    gm._dynamo_bytecode_flatten = None
    gm._dynamo_bytecode_unflatten = None

    with (
        torch._functorch.config.patch(bundled_autograd_cache=False),
        set_current_vllm_config(vllm_config),
    ):
        # Build the serializable function while the flag is flipped to True;
        # this is the config value the artifacts should remember.
        with torch._functorch.config.patch(bundled_autograd_cache=True):
            fn = VllmSerializableFunction(gm, example_inputs, "", plus_one)

        payload = VllmSerializableFunction.serialize_compile_artifacts(fn)

        observed_config = None

        def backend(*args, **kwargs) -> VllmSerializableFunction:
            nonlocal observed_config
            # bundled_autograd_cache should be True even though the compiler
            # backend runs with bundled_autograd_cache=False in the ambient
            # context.
            observed_config = torch._functorch.config.save_config_portable()
            return fn

        loaded_fn = VllmSerializableFunction.deserialize_compile_artifacts(
            payload
        )
        with patch.object(VllmBackend, "__call__", backend):
            loaded_fn(*example_inputs)

        assert isinstance(observed_config, dict)
        assert "bundled_autograd_cache" in observed_config
        assert observed_config["bundled_autograd_cache"] is True
31 changes: 23 additions & 8 deletions vllm/compilation/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def __init__(
is_encoder: bool = False,
vllm_backend: Any | None = None,
sym_tensor_indices: list[int] | None = None,
aot_autograd_config: dict[str, Any] | None = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for my understanding, do you need inductor config too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory, we should also preserve inductor config. In practice, inductor config doesn't change with AOT on/off like functorch, so backend key will stay the same after #35472.

) -> None:
assert isinstance(graph_module, torch.fx.GraphModule)
self.graph_module = graph_module
Expand All @@ -188,6 +189,13 @@ def __init__(
self.shape_env = None
self.vllm_backend = vllm_backend
self.sym_tensor_indices = sym_tensor_indices

import torch._functorch.config as functorch_config

self.aot_autograd_config = (
aot_autograd_config or functorch_config.save_config_portable()
)

sym_input = next(
(i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
)
Expand Down Expand Up @@ -286,6 +294,12 @@ def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction
sym_shape_indices_map = state.pop("sym_shape_indices_map", {})
returns_tuple_map = state.pop("returns_tuple_map", {})

saved_aot_autograd_config = state["aot_autograd_config"]
if saved_aot_autograd_config is not None:
functorch_ctx = torch._functorch.config.patch(saved_aot_autograd_config)
else:
functorch_ctx = contextlib.nullcontext()

if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
assert standalone_compile_artifacts is not None
submod_names = standalone_compile_artifacts.submodule_names()
Expand All @@ -299,13 +313,14 @@ def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction
num_submods,
)

fn = reconstruct_serializable_fn_from_mega_artifact(
state=state,
standalone_compile_artifacts=standalone_compile_artifacts,
vllm_config=get_current_vllm_config(),
sym_shape_indices_map=sym_shape_indices_map,
returns_tuple_map=returns_tuple_map,
)
with functorch_ctx:
fn = reconstruct_serializable_fn_from_mega_artifact(
state=state,
standalone_compile_artifacts=standalone_compile_artifacts,
vllm_config=get_current_vllm_config(),
sym_shape_indices_map=sym_shape_indices_map,
returns_tuple_map=returns_tuple_map,
)

logger.info(
"reconstructed serializable fn from standalone compile artifacts"
Expand All @@ -328,7 +343,7 @@ def optimized_call(*example_inputs: Any) -> Any:
vllm_backend: VllmBackend = VllmBackend(
vllm_config, state["prefix"], is_encoder
)
with tracing(TracingContext(fake_mode)):
with tracing(TracingContext(fake_mode)), functorch_ctx:
fn.optimized_call = vllm_backend(
state["graph_module"], compile_inputs
).optimized_call
Expand Down