Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors

from .compiler_interface import EagerAdaptor, InductorAdaptor
from .compiler_interface import (CompilerInterface, EagerAdaptor,
InductorAdaptor, InductorStandaloneAdaptor)
from .counter import compilation_counter
from .inductor_pass import InductorPass
from .monitor import end_monitoring_torch_compile
Expand All @@ -26,6 +27,19 @@
logger = init_logger(__name__)


def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    """Select the compiler backend based on the compilation config.

    Eager mode short-circuits everything; otherwise the
    VLLM_TEST_STANDALONE_COMPILE env flag picks between the standalone
    Inductor adaptor (PyTorch 2.8+) and the classic one.
    """
    if not compilation_config.use_inductor:
        logger.info("Using EagerAdaptor")
        return EagerAdaptor()
    if envs.VLLM_TEST_STANDALONE_COMPILE:
        logger.info("Using InductorStandaloneAdaptor")
        return InductorStandaloneAdaptor()
    logger.info("Using InductorAdaptor")
    return InductorAdaptor()


class CompilerManager:
"""
A manager to manage the compilation process, including
Expand All @@ -41,11 +55,11 @@ class CompilerManager:
support int as key.
"""

def __init__(self, use_inductor: bool):
def __init__(self, compilation_config: CompilationConfig):
self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
cls = InductorAdaptor if use_inductor else EagerAdaptor
self.compiler = cls()
self.is_cache_updated = False
self.compilation_config = compilation_config
self.compiler = make_compiler(compilation_config)

def compute_hash(self, vllm_config: VllmConfig) -> str:
return self.compiler.compute_hash(vllm_config)
Expand Down Expand Up @@ -123,8 +137,15 @@ def compile(self,

# no compiler cached the graph, or the cache is disabled,
# we need to compile it
if isinstance(self.compiler, InductorAdaptor):
# Let compile_fx generate a key for us
maybe_key = None
else:
maybe_key = \
f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
compiled_graph, handle = self.compiler.compile(
graph, example_inputs, additional_inductor_config, runtime_shape)
graph, example_inputs, additional_inductor_config, runtime_shape,
maybe_key)

assert compiled_graph is not None, "Failed to compile the graph"

Expand Down Expand Up @@ -336,7 +357,7 @@ def __init__(
self.compilation_config = vllm_config.compilation_config

self.compiler_manager: CompilerManager = CompilerManager(
self.compilation_config.use_inductor)
self.compilation_config)

# `torch.compile` is JIT compiled, so we don't need to
# do anything here
Expand Down
141 changes: 118 additions & 23 deletions vllm/compilation/compiler_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def compile(
graph: fx.GraphModule,
example_inputs: List[Any],
compiler_config: Dict[str, Any],
runtime_shape: Optional[int] = None
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
Comment thread
zou3519 marked this conversation as resolved.
Outdated
) -> Tuple[Optional[Callable], Optional[Any]]:
"""
Compile the graph with the given example inputs and compiler config,
Expand All @@ -71,6 +72,10 @@ def compile(
If the compiler doesn't support caching, it should return None for the
handle. If the compiler fails to compile the graph, it should return
None for the compiled function as well.

        `key` is required for InductorStandaloneAdaptor; it specifies where to
        save the compiled artifact. The compiled artifact gets saved to
        `cache_dir/key`.
"""
return None, None

Expand Down Expand Up @@ -127,23 +132,108 @@ def produce_guards_expression(self, *args, **kwargs):
return ""


def get_inductor_factors() -> List[Any]:
factors: List[Any] = []
# summarize system state
from torch._inductor.codecache import CacheBase
system_factors = CacheBase.get_system()
factors.append(system_factors)

# summarize pytorch state
from torch._inductor.codecache import torch_key
torch_factors = torch_key()
factors.append(torch_factors)
return factors


class InductorStandaloneAdaptor(CompilerInterface):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should add a unittest for this adaptor. Feel free to do it in the following PR.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the following PR I will turn on InductorStandaloneAdaptor for PyTorch >= 2.8, which will make it so that it gets tested in the "torch nightly" vLLM CI.

"""
    The adaptor for the Inductor compiler's standalone_compile API.
Requires PyTorch 2.8+.
This is not on by default yet, but we plan to turn it on by default for
PyTorch 2.8.

Use VLLM_TEST_STANDALONE_COMPILE to toggle this on or off.
"""
name = "inductor_standalone"

def compute_hash(self, vllm_config: VllmConfig) -> str:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to consider the none pytorch source code in the hash?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean? The compute_hash function here is the same as the compute_hash function in InductorAdaptor. I can put this into a helper function for better code reuse.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Offline synced, it will be called by vLLM to compute the overall cache key.

factors = get_inductor_factors()
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()[:10]
return hash_str

    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
        # Remember where compiled artifacts should be written; `compile`
        # later saves each artifact under `cache_dir/key`.
        # NOTE(review): `disable_cache` is accepted for interface
        # compatibility but is not consulted here — confirm that standalone
        # compile intentionally ignores it.
        self.cache_dir = cache_dir

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: List[Any],
        compiler_config: Dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> Tuple[Optional[Callable], Optional[Any]]:
        """Compile `graph` with Inductor's standalone_compile (PyTorch 2.8+).

        Returns the compiled callable plus a `(key, path)` handle; `load`
        uses the handle to restore the artifact from disk later.
        `key` must be provided — the artifact is saved to `cache_dir/key`.
        """
        # Merge caller-provided Inductor options, then apply the
        # shape-specific tuning knobs (autotune for static shapes).
        current_config = {}
        if compiler_config is not None:
            current_config.update(compiler_config)
        set_inductor_config(current_config, runtime_shape)

        # An int runtime_shape means we compiled for one concrete batch
        # size, so the example inputs fully determine the (static) shapes;
        # otherwise let Inductor pick shapes up from the tracing context.
        if isinstance(runtime_shape, int):
            dynamic_shapes = "from_example_inputs"
        else:
            dynamic_shapes = "from_tracing_context"

        from torch._inductor import standalone_compile
        with pass_context(runtime_shape):
            compiled_graph = standalone_compile(
                graph,
                example_inputs,
                dynamic_shapes=dynamic_shapes,
                options={"config_patches": current_config})

        # Save the compiled artifact to disk in the specified path
        assert key is not None
        path = os.path.join(self.cache_dir, key)
        compiled_graph.save(path=path, format="unpacked")
        return compiled_graph, (key, path)

def load(self,
handle: Any,
graph: fx.GraphModule,
example_inputs: List[Any],
graph_index: int,
runtime_shape: Optional[int] = None) -> Callable:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
assert isinstance(handle[1], str)
path = handle[1]
inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
path=path, format="unpacked")
from torch._inductor.compile_fx import graph_returns_tuple
returns_tuple = graph_returns_tuple(graph)

def compiled_graph_wrapper(*args):
graph_output = inductor_compiled_graph(*args)
# unpack the tuple if needed
# TODO(rzou): the implication is that we're not
# reading the python bytecode correctly in vLLM?
Comment on lines 219 to 220
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zou3519 Could you explain this comment? Would like to understand the sketchiness

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a pre-existing problem, the pre-existing InductorAdaptor also has this code in it.

If a function being compiled returns a single tensor, e.g. f(x) = x.sin(), and we compile this, then Inductor is always passed a graph that returns a (Tensor,) and Inductor returns a compiled artifact that returns a (Tensor,). Dynamo is responsible for unpacking this back into a single tensor via the bytecode it generates.

vLLM takes the bytecode that Dynamo generates and turns it into some Python code that wraps the compiled artifact. However, since we also need to manually do the unpacking here, I suspect that vLLM is not doing that process correctly.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, thanks for the explanation.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we also need to manually do the unpacking here, I suspect that vLLM is not doing that process correctly

what does this mean? as you mentioned, we have special handling logic for the case when the original graph returns a single tensor, and I think vLLM is correct here.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does this mean? as you mentioned, we have special handling logic for the case when the original graph returns a single tensor, and I think vLLM is correct here.

In torch.compile, the handling logic for what happens when the original graph returns a single tensor is in the Dynamo-produced bytecode. In vLLM, the handling logic is in the InductorAdaptor. I would expect it to be in the Dynamo-produced bytecode.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@youkaichao mentioned to me that vLLM does use the Dynamo-produced bytecode directly so... this needs more investigation

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's track it in the following PR? (like creating some issue?)

if returns_tuple:
return graph_output
else:
return graph_output[0]

return compiled_graph_wrapper


class InductorAdaptor(CompilerInterface):
"""
The adaptor for the Inductor compiler, version 2.5 and 2.6.
The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.
"""
name = "inductor"

def compute_hash(self, vllm_config: VllmConfig) -> str:
factors: List[Any] = []
# summarize system state
from torch._inductor.codecache import CacheBase
system_factors = CacheBase.get_system()
factors.append(system_factors)

# summarize pytorch state
from torch._inductor.codecache import torch_key
torch_factors = torch_key()
factors.append(torch_factors)
factors = get_inductor_factors()
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()[:10]
return hash_str
Expand All @@ -168,23 +258,19 @@ def compile(
graph: fx.GraphModule,
example_inputs: List[Any],
compiler_config: Dict[str, Any],
runtime_shape: Optional[int] = None
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> Tuple[Optional[Callable], Optional[Any]]:
current_config = {}
from torch._inductor.compile_fx import compile_fx
Comment thread
zou3519 marked this conversation as resolved.
Outdated
current_config = {}
if compiler_config is not None:
current_config.update(compiler_config)

# disable remote cache
current_config["fx_graph_cache"] = True
current_config["fx_graph_remote_cache"] = False

if compiler_config is not None:
current_config.update(compiler_config)

if isinstance(runtime_shape, int):
# for a specific batchsize, tuning triton kernel parameters
# can be beneficial
current_config["max_autotune"] = True
current_config["coordinate_descent_tuning"] = True
set_inductor_config(current_config, runtime_shape)

# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
Expand Down Expand Up @@ -422,6 +508,14 @@ def metrics_context(self) -> contextlib.AbstractContextManager:
return contextlib.nullcontext()


def set_inductor_config(config, runtime_shape):
    """Enable shape-specific Inductor tuning options in `config` in place.

    Only applies when compiling for one concrete batch size (an int
    `runtime_shape`): there, spending extra compile time autotuning
    triton kernel parameters pays off at runtime.
    """
    if not isinstance(runtime_shape, int):
        return
    config["max_autotune"] = True
    config["coordinate_descent_tuning"] = True


class EagerAdaptor(CompilerInterface):
name = "eager"

Expand All @@ -430,7 +524,8 @@ def compile(
graph: fx.GraphModule,
example_inputs: List[Any],
compiler_config: Dict[str, Any],
runtime_shape: Optional[int] = None
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> Tuple[Optional[Callable], Optional[Any]]:
# we don't need to compile the graph, just return the graph itself.
# It does not support caching, return None for the handle.
Expand Down
5 changes: 5 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
lambda: bool(
os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),

# Internal flag to enable/disable Inductor standalone compile
"VLLM_TEST_STANDALONE_COMPILE":
lambda: os.environ.get("VLLM_TEST_STANDALONE_COMPILE", "0") != "0",

# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK":
Expand Down Expand Up @@ -789,6 +793,7 @@ def factorize(name: str):
"VLLM_USE_TRITON_AWQ",
"VLLM_DP_RANK",
"VLLM_DP_SIZE",
"VLLM_TEST_STANDALONE_COMPILE",
]
for key in environment_variables_to_hash:
if key in environment_variables:
Expand Down