diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 0946fa69171b..da91b39e11f5 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -290,9 +290,20 @@ def __init__( return compilation_counter.num_models_seen += 1 - TorchCompileWrapperWithCustomDispatcher.__init__( - self, compilation_mode=vllm_config.compilation_config.mode - ) + if not hasattr(self.__class__, "compiled_callable"): + # only compile the same model once + # NOTE: this is probably not right, since parameters can change + # and cause us to fall over + TorchCompileWrapperWithCustomDispatcher.__init__( + self, compilation_mode=vllm_config.compilation_config.mode + ) + self.__class__.compiled_callable = self.compiled_callable + else: + TorchCompileWrapperWithCustomDispatcher.__init__( + self, + self.__class__.compiled_callable, + compilation_mode=vllm_config.compilation_config.mode, + ) cls.__init__ = __init__ @@ -363,7 +374,7 @@ def __call__(self, *args, **kwargs): return self.aot_compiled_fn(self, *args, **kwargs) # the first compilation needs to have dynamic shapes marked - if len(self.compiled_codes) < 1: + if len(self.__class__.compiled_codes) < 1: sig = inspect.signature(self.__class__.forward) bound_args = sig.bind(self, *args, **kwargs) bound_args.apply_defaults() @@ -403,7 +414,7 @@ def __call__(self, *args, **kwargs): # if we don't use custom dispatcher, we can directly call the # compiled function and let torch.compile handle the dispatching, # with the overhead of guard evaluation and recompilation. - if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher: + if len(self.__class__.compiled_codes) < 1 or not self.use_custom_dispatcher: # it seems Dynamo reuse the compilation across instances, # while we need to make sure the compiled code is not reused. # we need to control all the compilation of the model. @@ -451,13 +462,16 @@ def patched_inline_call(self_): _torch27_patch_tensor_subclasses(), ): if envs.VLLM_USE_AOT_COMPILE: - self.aot_compiled_fn = self.aot_compile(*args, **kwargs) - output = self.aot_compiled_fn(self, *args, **kwargs) + if not hasattr(self.__class__, "aot_compiled_fn"): + self.__class__.aot_compiled_fn = self.aot_compile( + *args, **kwargs + ) + output = self.__class__.aot_compiled_fn(self, *args, **kwargs) assert aot_compilation_path is not None assert cache_dir is not None try: os.makedirs(cache_dir, exist_ok=True) - self.aot_compiled_fn.save_compiled_function( + self.__class__.aot_compiled_fn.save_compiled_function( aot_compilation_path ) except Exception as e: diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 4b10c85209f6..6c6ced8ee9eb 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -65,7 +65,7 @@ def __init__( self.compiled_callable = compiled_callable self.original_code_object = self.__class__.forward.__code__ - self.compiled_codes: list[CodeType] = [] + self.__class__.compiled_codes = [] # type: ignore[attr-defined] torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) # read the env var to determine whether to use the custom dispatcher @@ -112,7 +112,9 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType): if frame.f_locals["self"] is not self: return - self.compiled_codes.append(new_code) + self.__class__.compiled_codes.append( # type: ignore[attr-defined] + new_code + ) path = self.vllm_config.compile_debug_dump_path() if path: @@ -161,6 +163,8 @@ def dispatch_to_code(self, index: int): See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details. """ - self.__class__.forward.__code__ = self.compiled_codes[index] + self.__class__.forward.__code__ = self.__class__.compiled_codes[ # type: ignore[attr-defined] + index + ] yield self.__class__.forward.__code__ = self.original_code_object