Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions vllm/compilation/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,20 @@ def __init__(
return

compilation_counter.num_models_seen += 1
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_mode=vllm_config.compilation_config.mode
)
if not hasattr(self.__class__, "compiled_callable"):
# only compile the same model once
# NOTE: this is probably not right, since parameters can change
# and cause us to fall over
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_mode=vllm_config.compilation_config.mode
)
self.__class__.compiled_callable = self.compiled_callable
else:
TorchCompileWrapperWithCustomDispatcher.__init__(
self,
self.__class__.compiled_callable,
compilation_mode=vllm_config.compilation_config.mode,
)

cls.__init__ = __init__

Expand Down Expand Up @@ -363,7 +374,7 @@ def __call__(self, *args, **kwargs):
return self.aot_compiled_fn(self, *args, **kwargs)

# the first compilation needs to have dynamic shapes marked
if len(self.compiled_codes) < 1:
if len(self.__class__.compiled_codes) < 1:
sig = inspect.signature(self.__class__.forward)
bound_args = sig.bind(self, *args, **kwargs)
bound_args.apply_defaults()
Expand Down Expand Up @@ -403,7 +414,7 @@ def __call__(self, *args, **kwargs):
# if we don't use custom dispatcher, we can directly call the
# compiled function and let torch.compile handle the dispatching,
# with the overhead of guard evaluation and recompilation.
if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
if len(self.__class__.compiled_codes) < 1 or not self.use_custom_dispatcher:
# it seems Dynamo reuse the compilation across instances,
# while we need to make sure the compiled code is not reused.
# we need to control all the compilation of the model.
Expand Down Expand Up @@ -451,13 +462,16 @@ def patched_inline_call(self_):
_torch27_patch_tensor_subclasses(),
):
if envs.VLLM_USE_AOT_COMPILE:
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
output = self.aot_compiled_fn(self, *args, **kwargs)
if not hasattr(self.__class__, "aot_compiled_fn"):
self.__class__.aot_compiled_fn = self.aot_compile(
*args, **kwargs
)
output = self.__class__.aot_compiled_fn(self, *args, **kwargs)
assert aot_compilation_path is not None
assert cache_dir is not None
try:
os.makedirs(cache_dir, exist_ok=True)
self.aot_compiled_fn.save_compiled_function(
self.__class__.aot_compiled_fn.save_compiled_function(
aot_compilation_path
)
except Exception as e:
Expand Down
10 changes: 7 additions & 3 deletions vllm/compilation/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(

self.compiled_callable = compiled_callable
self.original_code_object = self.__class__.forward.__code__
self.compiled_codes: list[CodeType] = []
self.__class__.compiled_codes = [] # type: ignore[attr-defined]
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

# read the env var to determine whether to use the custom dispatcher
Expand Down Expand Up @@ -112,7 +112,9 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
if frame.f_locals["self"] is not self:
return

self.compiled_codes.append(new_code)
self.__class__.compiled_codes.append( # type: ignore[attr-defined]
new_code
)

path = self.vllm_config.compile_debug_dump_path()
if path:
Expand Down Expand Up @@ -161,6 +163,8 @@ def dispatch_to_code(self, index: int):
See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7
for more details.
"""
self.__class__.forward.__code__ = self.compiled_codes[index]
self.__class__.forward.__code__ = self.__class__.compiled_codes[ # type: ignore[attr-defined]
index
]
yield
self.__class__.forward.__code__ = self.original_code_object