[torch.compile] avoid Dynamo guard evaluation overhead (vllm-project#7898)
Co-authored-by: Woosuk Kwon <[email protected]>
Showing 9 changed files with 190 additions and 11 deletions.
@@ -0,0 +1,59 @@
from typing import Optional

import torch

from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther


class MyMod(torch.nn.Module):

    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        if cache is not None:
            return x + cache
        return x * 2


class MyWrapper(TorchCompileWrapperWithCustomDispacther):

    def __init__(self, model):
        self.model = model
        compiled_callable = torch.compile(self.forward, backend="eager")
        super().__init__(compiled_callable)

    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        # this is the function to be compiled
        return self.model(x, cache)

    def __call__(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        # let torch.compile compile twice
        if len(self.compiled_codes) == 2:
            dispatch_id = 0 if cache is None else 1
            with self.dispatch_to_code(dispatch_id):
                return self.forward(x, cache)
        else:
            return self.compiled_callable(x, cache)


def test_torch_compile_wrapper():
    mod = MyMod()
    wrappers = []
    for i in range(3):
        torch._dynamo.reset()
        wrapper = MyWrapper(mod)
        wrappers.append(wrapper)
        x = torch.tensor([1])
        wrapper(x, None)  # profile run, compile
        # create a cache tensor
        cache = torch.tensor([2])
        wrapper(x, cache)  # warm up with cache, recompile

        # for new input, dispatch to the compiled code directly
        new_x = torch.tensor([3])
        assert wrapper(new_x, None).item() == 6  # dispatch to the first compiled code
        assert wrapper(new_x, cache).item() == 5  # dispatch to the second compiled code

    for wrapper in wrappers:
        # make sure they have independent compiled codes
        assert len(wrapper.compiled_codes) == 2
Empty file.
@@ -0,0 +1,9 @@
from ..utils import compare_two_settings


def test_custom_dispatcher():
    compare_two_settings("google/gemma-2b",
                         arg1=["--enforce-eager"],
                         arg2=["--enforce-eager"],
                         env1={"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"},
                         env2={})
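For context, here is a minimal sketch of how a boolean flag such as `VLLM_DYNAMO_USE_CUSTOM_DISPATCHER` is typically parsed; this is an illustrative assumption, since the actual parsing lives in `vllm.envs` and may differ. It shows why `env1` (the variable set to "0") disables the custom dispatcher, while `env2` (left empty, so the variable stays unset) keeps the default behavior, letting the test compare the two paths for identical output.

import os

def use_custom_dispatcher() -> bool:
    # Illustrative assumption only: unset or truthy values enable the custom
    # dispatcher; "0" disables it, which is what env1 in the test exercises.
    return os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "1").lower() in ("1", "true")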
Empty file.
@@ -0,0 +1,81 @@
import os
import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
from typing import Callable, List

import torch

import vllm.envs as envs


class TorchCompileWrapperWithCustomDispacther:
    """
    A wrapper class for torch.compile, with a custom dispatch logic.
    Subclasses should:
    1. Implement the forward method
    2. Implement the dispatch logic in the __call__ method
        It can use `self.compiled_codes` to access the compiled bytecode,
        and `with self.dispatch_to_code(index):` to dispatch to
        the compiled code.
    3. Implement the `__init__` method to determine how to call
        `torch.compile` over the forward method.
    """

    def __init__(self, compiled_callable: Callable):
        self.compiled_callable = compiled_callable
        self.original_code_object = self.__class__.forward.__code__
        self.compiled_codes: List[CodeType] = []
        torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

        # read the env var to determine whether to use the custom dispatcher
        # subclasses can use this to switch between the custom dispatcher
        # and the default Dynamo guard mechanism.
        self.use_custom_dispatcher: bool = \
            envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER

    def __call__(self, *args, **kwargs):
        """Implement the dispatch logic here, beyond the torch.compile level.
        NOTE: this function can have additional arguments beyond the forward
        method, for directly dispatching to the compiled code.
        """
        return self.compiled_callable(*args, **kwargs)

    @abstractmethod
    def forward(self, *args, **kwargs):
        ...

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
        """Hook to save the compiled bytecode for direct execution."""
        if old_code is not self.original_code_object:
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        frame = sys._getframe()
        while True:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        if frame.f_locals["self"] is not self:
            return

        self.compiled_codes.append(new_code)

    @contextmanager
    def dispatch_to_code(self, index: int):
        """Context manager to dispatch to the compiled code.
        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.
        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
        """  # noqa
        self.__class__.forward.__code__ = self.compiled_codes[index]
        yield
        self.__class__.forward.__code__ = self.original_code_object
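To illustrate the mechanism that `dispatch_to_code` relies on, here is a minimal standalone sketch (not part of this commit): swapping a function's `__code__` attribute redirects subsequent calls to the new bytecode. This is only safe when both code objects take the same arguments and share the same cell and free variables, which is exactly what Dynamo guarantees for its compiled bytecode relative to the original function.

def original(x):
    return x * 2

def replacement(x):
    return x + 100

# Swap in the other bytecode, call, then restore, mirroring what the
# dispatch_to_code context manager does with the recorded compiled codes.
saved = original.__code__
original.__code__ = replacement.__code__
assert original(1) == 101  # now runs replacement's bytecode, no guard checks involved
original.__code__ = saved
assert original(1) == 2    # original behavior restored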