From 248d4dbd0f52a42319cb329af154c65f89507fd9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 26 Aug 2024 22:47:01 -0700 Subject: [PATCH 01/22] custom dispatch --- vllm/worker/tpu_model_runner.py | 41 ++++++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 01daa64b5a32f..992f0fd50b6b0 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -1,5 +1,6 @@ import time from dataclasses import dataclass +from types import CodeType from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union from unittest.mock import patch @@ -144,11 +145,7 @@ def load_model(self) -> None: ) model = model.eval() xm.wait_device_ops() - model = ModelWrapper(model) - self.model = torch.compile(model, - backend="openxla", - fullgraph=True, - dynamic=False) + self.model = ModelWrapper(model) def _dummy_run( self, @@ -530,7 +527,7 @@ def _execute_model(*args): if getattr(arg, "context_lens", None) is not None: arg.context_lens = arg.context_lens.to(self.device) new_args.append(arg) - return self.model(*new_args) + return self.model(*new_args, is_prompt=is_prompt) num_prefills = model_input.attn_metadata.num_prefills is_prompt = num_prefills > 0 @@ -601,12 +598,23 @@ def _execute_model(*args): return [SamplerOutput(sampler_outputs)] -class ModelWrapper(nn.Module): +class ModelWrapper: def __init__(self, model: nn.Module): - super().__init__() self.model = model + def __call__(self, *args, is_prompt: bool = False, **kwargs): + if len(ModelWrapper.compiled_codes) < 3: + # not fully compiled yet, let PyTorch handle it + return self.compiled_forward(*args, **kwargs) + # dispatch to the compiled code directly, skip PyTorch + if is_prompt: + ModelWrapper.forward.__code__ = ModelWrapper.compiled_codes[1] + return self.forward(*args, **kwargs) + else: + ModelWrapper.forward.__code__ = ModelWrapper.compiled_codes[2] + return self.forward(*args, **kwargs) + def forward( self, token_ids: torch.Tensor, @@ -695,6 +703,23 @@ def forward( argmax_token_ids) return next_token_ids + compiled_forward = torch.compile(forward, + backend="openxla", + fullgraph=True, + dynamic=False) + + target_code = forward.__code__ + compiled_codes: List[CodeType] = [] + + @staticmethod + def collect_bytecode_hook(old, new): + global compiled_codes + if old is ModelWrapper.target_code: + ModelWrapper.compiled_codes.append(new) + print(ModelWrapper.compiled_codes) + + torch._dynamo.convert_frame.register_bytecode_hook(collect_bytecode_hook) + def _get_padded_prefill_len(x: int) -> int: # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence From 8f4ed39374c4afa55a31183cfa6f83f902d1b351 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 26 Aug 2024 23:02:28 -0700 Subject: [PATCH 02/22] refine --- vllm/worker/tpu_model_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 992f0fd50b6b0..e81dbbd4aad3a 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -607,6 +607,10 @@ def __call__(self, *args, is_prompt: bool = False, **kwargs): if len(ModelWrapper.compiled_codes) < 3: # not fully compiled yet, let PyTorch handle it return self.compiled_forward(*args, **kwargs) + # the 3 compiled codes are: + # 0: for profiling + # 1: for prompt + # 2: for decode # dispatch to the compiled code directly, skip PyTorch if is_prompt: ModelWrapper.forward.__code__ = 
ModelWrapper.compiled_codes[1] @@ -716,7 +720,6 @@ def collect_bytecode_hook(old, new): global compiled_codes if old is ModelWrapper.target_code: ModelWrapper.compiled_codes.append(new) - print(ModelWrapper.compiled_codes) torch._dynamo.convert_frame.register_bytecode_hook(collect_bytecode_hook) From 2d8b20a050945856659275f75954f550ddb43d17 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 27 Aug 2024 00:57:43 -0700 Subject: [PATCH 03/22] add wrapper --- vllm/compilation/wrapper.py | 66 +++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 vllm/compilation/wrapper.py diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py new file mode 100644 index 0000000000000..b1f6c7286fd2c --- /dev/null +++ b/vllm/compilation/wrapper.py @@ -0,0 +1,66 @@ +import os +import sys +from abc import abstractmethod +from contextlib import contextmanager +from types import CodeType +from typing import Any, Callable, Dict, List, Optional + + +class TorchCompileWrapperWithCustomDispacther: + """ + A wrapper class for torch.compile, with a custom dispatch logic. + Subclasses should: + 1. Implement the forward method + 2. Implement the dispatch logic in the __call__ method + It can use `self.compiled_codes` to access the compiled bytecode, + and `with self.dispatch_to_code(index):` to dispatch to + the compiled code. + 3. Implement the `__init__` method to determine how to call + `torch.compile` over the forward method. + """ + + def __init__(self, compiled_callable: Callable): + self.compiled_callable = compiled_callable + self.original_code_object = self.__class__.forward.__code__ + self.compiled_codes: List[CodeType] = [] + + def __call__(self, + *args, + dispatch_args: Optional[Dict[str, Any]] = None, + **kwargs): + """Implement the dispatch logic here, beyond the torch.compile level, + according to the dispatch_args, which is not visible to torch.compile. + """ + return self.compiled_callable(*args, **kwargs) + + @abstractmethod + def forward(self, *args, **kwargs): + ... 
+ + def bytecode_hook(self, old_code: CodeType, new_code: CodeType): + """Hook to save the compiled bytecode for direct execution.""" + if old_code is not self.original_code_object: + return + + # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25 + frame = sys._getframe() + while True: + frame = frame.f_back + code_name = frame.f_code.co_name + file_name = frame.f_code.co_filename.split(os.path.sep)[-1] + if code_name == "_compile" and file_name == "convert_frame.py": + break + frame = frame.f_locals["frame"] + assert frame.f_code == old_code + + if frame.f_locals["self"] is not self: + return + + self.compiled_codes.append(new_code) + + @contextmanager + def dispatch_to_code(self, index: int): + """Context manager to dispatch to the compiled code.""" + self.__class__.forward.__code__ = self.compiled_codes[index] + yield + self.__class__.forward.__code__ = self.original_code_object From 9f752fde72c4bd405162e40baeece7b5fee539ce Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 27 Aug 2024 09:02:11 -0700 Subject: [PATCH 04/22] update --- vllm/compilation/wrapper.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index b1f6c7286fd2c..4c390e8ed5ded 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -3,7 +3,7 @@ from abc import abstractmethod from contextlib import contextmanager from types import CodeType -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, List class TorchCompileWrapperWithCustomDispacther: @@ -24,12 +24,10 @@ def __init__(self, compiled_callable: Callable): self.original_code_object = self.__class__.forward.__code__ self.compiled_codes: List[CodeType] = [] - def __call__(self, - *args, - dispatch_args: Optional[Dict[str, Any]] = None, - **kwargs): - """Implement the dispatch logic here, beyond the torch.compile level, - according to the dispatch_args, which is not visible to torch.compile. + def __call__(self, *args, **kwargs): + """Implement the dispatch logic here, beyond the torch.compile level. + NOTE: this function can have additional arguments beyond the forward + method, for directly dispatching to the compiled code. 
""" return self.compiled_callable(*args, **kwargs) From 4be616aa852b996a6e56f83a3158cda3dc913bd7 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 27 Aug 2024 09:20:27 -0700 Subject: [PATCH 05/22] add wrapper test --- tests/compile/test_full_graph.py | 39 ++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index d5b59db8c7887..b14c8cebb22e0 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -1,6 +1,45 @@ import os import pytest +import torch + +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther + + +class MyMod(torch.nn.Module): + + def forward(self, x: torch.Tensor, a: int): + return (x + a) * 2 + + +class MyWrapper(TorchCompileWrapperWithCustomDispacther): + + def __init__(self, model): + self.model = model + compiled_callable = torch.compile(self.forward, backend="eager") + super().__init__(compiled_callable) + + def forward(self, x: torch.Tensor, a: int): + # this is the function to be compiled + return self.model(x, a) + + def __call__(self, x, a): + # let torch.compile compile twice + if len(self.compiled_codes) >= 2: + with self.dispatch_to_code(0): + return self.compiled_callable(x, a) + else: + return self.compiled_callable(x) + + +def test_torch_compile_wrapper(): + mod = MyMod() + wrapper = MyWrapper(mod) + x = torch.tensor([1.0]) + wrapper(x, 0) # first time, compile + wrapper(x, 1) # second time, compile + wrapper(x, 2) # third time, dispatch to the first compiled code + assert len(wrapper.compiled_codes) == 2 @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) From 026a52570719fcd212383f2db660bdd90205b23f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 27 Aug 2024 09:24:13 -0700 Subject: [PATCH 06/22] fix --- tests/compile/test_full_graph.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index b14c8cebb22e0..7e61a20fa77bc 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -23,13 +23,13 @@ def forward(self, x: torch.Tensor, a: int): # this is the function to be compiled return self.model(x, a) - def __call__(self, x, a): + def __call__(self, x: torch.Tensor, a: int): # let torch.compile compile twice if len(self.compiled_codes) >= 2: with self.dispatch_to_code(0): - return self.compiled_callable(x, a) + return self.forward(x, a) else: - return self.compiled_callable(x) + return self.compiled_callable(x, a) def test_torch_compile_wrapper(): From 7a1dd3873380c348593c569afa1609ecfd44a56c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 27 Aug 2024 09:29:47 -0700 Subject: [PATCH 07/22] update wrapper --- vllm/compilation/wrapper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 4c390e8ed5ded..1d2c1111a222e 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -5,6 +5,8 @@ from types import CodeType from typing import Callable, List +import torch + class TorchCompileWrapperWithCustomDispacther: """ @@ -23,6 +25,7 @@ def __init__(self, compiled_callable: Callable): self.compiled_callable = compiled_callable self.original_code_object = self.__class__.forward.__code__ self.compiled_codes: List[CodeType] = [] + torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) def __call__(self, *args, **kwargs): """Implement the dispatch logic here, beyond the 
torch.compile level. @@ -39,7 +42,6 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType): """Hook to save the compiled bytecode for direct execution.""" if old_code is not self.original_code_object: return - # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25 frame = sys._getframe() while True: From 1f0f148346ee3a73b9ac32032283bb3355b4facd Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 27 Aug 2024 09:30:42 -0700 Subject: [PATCH 08/22] separate tests --- tests/compile/test_full_graph.py | 39 ------------------------------- tests/compile/test_wrapper.py | 40 ++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 39 deletions(-) create mode 100644 tests/compile/test_wrapper.py diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 7e61a20fa77bc..d5b59db8c7887 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -1,45 +1,6 @@ import os import pytest -import torch - -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther - - -class MyMod(torch.nn.Module): - - def forward(self, x: torch.Tensor, a: int): - return (x + a) * 2 - - -class MyWrapper(TorchCompileWrapperWithCustomDispacther): - - def __init__(self, model): - self.model = model - compiled_callable = torch.compile(self.forward, backend="eager") - super().__init__(compiled_callable) - - def forward(self, x: torch.Tensor, a: int): - # this is the function to be compiled - return self.model(x, a) - - def __call__(self, x: torch.Tensor, a: int): - # let torch.compile compile twice - if len(self.compiled_codes) >= 2: - with self.dispatch_to_code(0): - return self.forward(x, a) - else: - return self.compiled_callable(x, a) - - -def test_torch_compile_wrapper(): - mod = MyMod() - wrapper = MyWrapper(mod) - x = torch.tensor([1.0]) - wrapper(x, 0) # first time, compile - wrapper(x, 1) # second time, compile - wrapper(x, 2) # third time, dispatch to the first compiled code - assert len(wrapper.compiled_codes) == 2 @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py new file mode 100644 index 0000000000000..60b9b4d43acdd --- /dev/null +++ b/tests/compile/test_wrapper.py @@ -0,0 +1,40 @@ +import torch + +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther + + +class MyMod(torch.nn.Module): + + def forward(self, x: torch.Tensor, a: int): + return (x + a) * 2 + + +class MyWrapper(TorchCompileWrapperWithCustomDispacther): + + def __init__(self, model): + self.model = model + compiled_callable = torch.compile(self.forward, backend="eager") + super().__init__(compiled_callable) + + def forward(self, x: torch.Tensor, a: int): + # this is the function to be compiled + return self.model(x, a) + + def __call__(self, x: torch.Tensor, a: int): + # let torch.compile compile twice + if len(self.compiled_codes) >= 2: + with self.dispatch_to_code(0): + return self.forward(x, a) + else: + return self.compiled_callable(x, a) + + +def test_torch_compile_wrapper(): + mod = MyMod() + wrapper = MyWrapper(mod) + x = torch.tensor([1.0]) + wrapper(x, 0) # first time, compile + wrapper(x, 1) # second time, compile + wrapper(x, 2) # third time, dispatch to the first compiled code + assert len(wrapper.compiled_codes) == 2 + From 75311863350fd488a15510b909df3e30176aaa88 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 27 Aug 2024 09:31:09 -0700 Subject: 
[PATCH 09/22] add tests --- .buildkite/test-pipeline.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e406938647479..5364146033806 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -173,6 +173,7 @@ steps: - vllm/ commands: - pytest -v -s ./compile/test_full_graph.py + - pytest -v -s ./compile/test_wrapper.py - label: Vision Language Models Test # 42min From 31e9e7b27fb4fd21239cda662a96c25b3ebb2e57 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 27 Aug 2024 09:53:53 -0700 Subject: [PATCH 10/22] update tests --- tests/compile/test_wrapper.py | 38 ++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py index 60b9b4d43acdd..3c2e4fcb96c9a 100644 --- a/tests/compile/test_wrapper.py +++ b/tests/compile/test_wrapper.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther @@ -5,8 +7,10 @@ class MyMod(torch.nn.Module): - def forward(self, x: torch.Tensor, a: int): - return (x + a) * 2 + def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): + if cache is not None: + return x + cache + return x * 2 class MyWrapper(TorchCompileWrapperWithCustomDispacther): @@ -16,25 +20,31 @@ def __init__(self, model): compiled_callable = torch.compile(self.forward, backend="eager") super().__init__(compiled_callable) - def forward(self, x: torch.Tensor, a: int): + def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): # this is the function to be compiled - return self.model(x, a) + return self.model(x, cache) - def __call__(self, x: torch.Tensor, a: int): + def __call__(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): # let torch.compile compile twice - if len(self.compiled_codes) >= 2: - with self.dispatch_to_code(0): - return self.forward(x, a) + if len(self.compiled_codes) == 2: + dispatch_id = 0 if cache is None else 1 + with self.dispatch_to_code(dispatch_id): + return self.forward(x, cache) else: - return self.compiled_callable(x, a) + return self.compiled_callable(x, cache) def test_torch_compile_wrapper(): mod = MyMod() wrapper = MyWrapper(mod) - x = torch.tensor([1.0]) - wrapper(x, 0) # first time, compile - wrapper(x, 1) # second time, compile - wrapper(x, 2) # third time, dispatch to the first compiled code + x = torch.tensor([1]) + cache = torch.tensor([2]) + wrapper(x, None) # first time, compile + wrapper(x, cache) # second time, compile + + new_x = torch.tensor([3]) + assert wrapper(new_x, + None).item() == 6 # dispatch to the first compiled code + assert wrapper(new_x, + cache).item() == 5 # dispatch to the second compiled code assert len(wrapper.compiled_codes) == 2 - From ace38e252f5120467ca04034e196c71f77c3d086 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 27 Aug 2024 09:58:20 -0700 Subject: [PATCH 11/22] multi wrappers --- tests/compile/test_wrapper.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py index 3c2e4fcb96c9a..cef516ade27eb 100644 --- a/tests/compile/test_wrapper.py +++ b/tests/compile/test_wrapper.py @@ -36,15 +36,24 @@ def __call__(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): def test_torch_compile_wrapper(): mod = MyMod() - wrapper = MyWrapper(mod) - x = torch.tensor([1]) - cache = torch.tensor([2]) - 
wrapper(x, None) # first time, compile - wrapper(x, cache) # second time, compile - - new_x = torch.tensor([3]) - assert wrapper(new_x, - None).item() == 6 # dispatch to the first compiled code - assert wrapper(new_x, - cache).item() == 5 # dispatch to the second compiled code - assert len(wrapper.compiled_codes) == 2 + wrappers = [] + for i in range(3): + torch._dynamo.reset() + wrapper = MyWrapper(mod) + wrappers.append(wrapper) + x = torch.tensor([1]) + wrapper(x, None) # profile run, compile + # create a cache tensor + cache = torch.tensor([2]) + wrapper(x, cache) # warm up with cache, recompile + + # for new input, dispatch to the compiled code directly + new_x = torch.tensor([3]) + assert wrapper(new_x, + None).item() == 6 # dispatch to the first compiled code + assert wrapper( + new_x, cache).item() == 5 # dispatch to the second compiled code + + for wrapper in wrappers: + # make sure they have independent compiled codes + assert len(wrapper.compiled_codes) == 2 From 31a9e0660d2171e4fabca51d60779bdbb44d79fa Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 27 Aug 2024 10:02:43 -0700 Subject: [PATCH 12/22] use wrapper --- vllm/worker/tpu_model_runner.py | 33 +++++++++++---------------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index e81dbbd4aad3a..57e5b513ea3f4 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -1,6 +1,5 @@ import time from dataclasses import dataclass -from types import CodeType from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union from unittest.mock import patch @@ -11,6 +10,7 @@ import torch_xla.runtime as xr from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.logger import init_logger @@ -598,10 +598,15 @@ def _execute_model(*args): return [SamplerOutput(sampler_outputs)] -class ModelWrapper: +class ModelWrapper(TorchCompileWrapperWithCustomDispacther): def __init__(self, model: nn.Module): self.model = model + compiled_forward = torch.compile(self.forward, + backend="openxla", + fullgraph=True, + dynamic=False) + super().__init__(compiled_forward) def __call__(self, *args, is_prompt: bool = False, **kwargs): if len(ModelWrapper.compiled_codes) < 3: @@ -613,11 +618,11 @@ def __call__(self, *args, is_prompt: bool = False, **kwargs): # 2: for decode # dispatch to the compiled code directly, skip PyTorch if is_prompt: - ModelWrapper.forward.__code__ = ModelWrapper.compiled_codes[1] - return self.forward(*args, **kwargs) + with self.dispatch_to_code(1): + return self.forward(*args, **kwargs) else: - ModelWrapper.forward.__code__ = ModelWrapper.compiled_codes[2] - return self.forward(*args, **kwargs) + with self.dispatch_to_code(2): + return self.forward(*args, **kwargs) def forward( self, @@ -707,22 +712,6 @@ def forward( argmax_token_ids) return next_token_ids - compiled_forward = torch.compile(forward, - backend="openxla", - fullgraph=True, - dynamic=False) - - target_code = forward.__code__ - compiled_codes: List[CodeType] = [] - - @staticmethod - def collect_bytecode_hook(old, new): - global compiled_codes - if old is ModelWrapper.target_code: - ModelWrapper.compiled_codes.append(new) - - torch._dynamo.convert_frame.register_bytecode_hook(collect_bytecode_hook) - def _get_padded_prefill_len(x: int) -> 
int: # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence From 0a349f554dc497693dcb5d87c0cd7c64dcbe2e90 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 27 Aug 2024 10:07:50 -0700 Subject: [PATCH 13/22] fix --- vllm/worker/tpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 57e5b513ea3f4..5d029d5ccb438 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -609,7 +609,7 @@ def __init__(self, model: nn.Module): super().__init__(compiled_forward) def __call__(self, *args, is_prompt: bool = False, **kwargs): - if len(ModelWrapper.compiled_codes) < 3: + if len(self.compiled_codes) < 3: # not fully compiled yet, let PyTorch handle it return self.compiled_forward(*args, **kwargs) # the 3 compiled codes are: From 12cb1645347e78fc5f35fa80b9ab17914584f86f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 27 Aug 2024 10:11:26 -0700 Subject: [PATCH 14/22] fix --- vllm/worker/tpu_model_runner.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 5d029d5ccb438..30cbee108f616 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -602,16 +602,16 @@ class ModelWrapper(TorchCompileWrapperWithCustomDispacther): def __init__(self, model: nn.Module): self.model = model - compiled_forward = torch.compile(self.forward, - backend="openxla", - fullgraph=True, - dynamic=False) - super().__init__(compiled_forward) + compiled_callable = torch.compile(self.forward, + backend="openxla", + fullgraph=True, + dynamic=False) + super().__init__(compiled_callable) def __call__(self, *args, is_prompt: bool = False, **kwargs): if len(self.compiled_codes) < 3: # not fully compiled yet, let PyTorch handle it - return self.compiled_forward(*args, **kwargs) + return self.compiled_callable(*args, **kwargs) # the 3 compiled codes are: # 0: for profiling # 1: for prompt From f483660ca6881a4f12144b6776effd0b8d513598 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 27 Aug 2024 10:20:45 -0700 Subject: [PATCH 15/22] more explanation --- vllm/compilation/wrapper.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 1d2c1111a222e..7d2b0550864c5 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -60,7 +60,14 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType): @contextmanager def dispatch_to_code(self, index: int): - """Context manager to dispatch to the compiled code.""" + """Context manager to dispatch to the compiled code. + Why does this work? Because Dynamo guarantees that the compiled + bytecode has exactly the same arguments, cell variables, and free + variables as the original code. Therefore we can directly switch + the code object in the function and call it. + + See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details. 
+ """ # noqa self.__class__.forward.__code__ = self.compiled_codes[index] yield self.__class__.forward.__code__ = self.original_code_object From ec52afc32e2e87e7d2840176558b662832219403 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 27 Aug 2024 10:53:11 -0700 Subject: [PATCH 16/22] add tests --- tests/tpu/test_custom_dispatcher.py | 6 ++++++ vllm/compilation/wrapper.py | 8 ++++++++ vllm/envs.py | 4 ++++ vllm/worker/tpu_model_runner.py | 5 +++-- 4 files changed, 21 insertions(+), 2 deletions(-) create mode 100644 tests/tpu/test_custom_dispatcher.py diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py new file mode 100644 index 0000000000000..62ff33cb02f64 --- /dev/null +++ b/tests/tpu/test_custom_dispatcher.py @@ -0,0 +1,6 @@ +from ..utils import compare_two_settings + + +def test_custom_dispatcher(): + compare_two_settings("google/gemma-2b", [], [], + {"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"}) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 7d2b0550864c5..c3d863299dd06 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -7,6 +7,8 @@ import torch +import vllm.envs as envs + class TorchCompileWrapperWithCustomDispacther: """ @@ -27,6 +29,12 @@ def __init__(self, compiled_callable: Callable): self.compiled_codes: List[CodeType] = [] torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) + # read the env var to determine whether to use the custom dispatcher + # subclasses can use this to switch between the custom dispatcher + # and the default Dynamo guard mechanism. + self.use_custom_dispatcher: bool = \ + envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER + def __call__(self, *args, **kwargs): """Implement the dispatch logic here, beyond the torch.compile level. 
NOTE: this function can have additional arguments beyond the forward diff --git a/vllm/envs.py b/vllm/envs.py index 24e09ee0e055f..fdd21a6f9995c 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -196,6 +196,10 @@ def get_default_config_root(): # Internal flag to enable Dynamo graph capture "VLLM_TEST_DYNAMO_GRAPH_CAPTURE": lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), + "VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": + lambda: + (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in + ("true", "1")), # local rank of the process in the distributed setting, used to determine # the GPU device id diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 30cbee108f616..ca7248250adb0 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -609,8 +609,9 @@ def __init__(self, model: nn.Module): super().__init__(compiled_callable) def __call__(self, *args, is_prompt: bool = False, **kwargs): - if len(self.compiled_codes) < 3: - # not fully compiled yet, let PyTorch handle it + if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: + # not fully compiled yet, or not using the custom dispatcher, + # let PyTorch handle it return self.compiled_callable(*args, **kwargs) # the 3 compiled codes are: # 0: for profiling From fabce9a862d9769bf24b3499b4fb20908ddced6e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 27 Aug 2024 11:00:22 -0700 Subject: [PATCH 17/22] add package --- tests/tpu/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/tpu/__init__.py diff --git a/tests/tpu/__init__.py b/tests/tpu/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d From b9fff4c6421545927262bd6b3bcc8a28c590dcc4 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 27 Aug 2024 11:01:35 -0700 Subject: [PATCH 18/22] update tests --- tests/tpu/test_custom_dispatcher.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 62ff33cb02f64..7f3fb595321ad 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -2,5 +2,8 @@ def test_custom_dispatcher(): - compare_two_settings("google/gemma-2b", [], [], - {"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"}) + compare_two_settings("google/gemma-2b", + arg1=["--enforce-eager"], + arg2=["--enforce-eager"], + env1={"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"}, + env2={}) From f5019fc735aa0535f30ad672c0be50ddf951da1a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 27 Aug 2024 11:03:24 -0700 Subject: [PATCH 19/22] add tests --- .buildkite/run-tpu-test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh index 335ffd83fcd7a..6989c94d46a89 100644 --- a/.buildkite/run-tpu-test.sh +++ b/.buildkite/run-tpu-test.sh @@ -12,4 +12,4 @@ remove_docker_container # For HF_TOKEN. source /etc/environment # Run a simple end-to-end example. 
-docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py" +docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py" From e3692baba8b04d459b4a6a9819b44482742fa011 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 27 Aug 2024 11:51:24 -0700 Subject: [PATCH 20/22] add init --- vllm/compilation/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 vllm/compilation/__init__.py diff --git a/vllm/compilation/__init__.py b/vllm/compilation/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d From 746036ce8d5bf97441d53cbd5fcc9d1a4f85223a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 28 Aug 2024 15:26:16 -0700 Subject: [PATCH 21/22] Update vllm/worker/tpu_model_runner.py Co-authored-by: Woosuk Kwon --- vllm/worker/tpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index ca7248250adb0..42630b65bbaee 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -608,7 +608,7 @@ def __init__(self, model: nn.Module): dynamic=False) super().__init__(compiled_callable) - def __call__(self, *args, is_prompt: bool = False, **kwargs): + def __call__(self, *args, is_prompt: bool, **kwargs): if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: # not fully compiled yet, or not using the custom dispatcher, # let PyTorch handle it From a0bac86733200122ce290b884fb581be7fe1235a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 28 Aug 2024 15:29:33 -0700 Subject: [PATCH 22/22] fix args --- vllm/worker/tpu_model_runner.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 42630b65bbaee..a7ceb84effe91 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -232,8 +232,15 @@ def _dummy_run( torch._dynamo.mark_dynamic(t, 0) torch._dynamo.mark_dynamic(p, 0) # Dummy run. - self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, - num_samples, kv_caches) + self.model(token_ids, + position_ids, + attn_metadata, + input_lens, + t, + p, + num_samples, + kv_caches, + is_prompt=is_prompt) def warmup_model( self,
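
Note on the mechanism (illustrative, not part of the patches above): the wrapper registers
torch._dynamo.convert_frame.register_bytecode_hook to capture the bytecode Dynamo produces for
forward(), and dispatch_to_code() then installs that bytecode directly via the function's
__code__ attribute, so the hot path calls the compiled code without going back through
torch.compile's dispatch. As the dispatch_to_code docstring notes, this is safe because Dynamo
guarantees the compiled bytecode keeps the same arguments, cell variables, and free variables
as the original code. A minimal standalone sketch of the __code__-swapping idea, using
hypothetical function names (add, add_optimized) rather than code from these patches:

    def add(x: int, y: int) -> int:
        # original implementation
        return x + y

    def add_optimized(x: int, y: int) -> int:
        # stand-in for Dynamo-compiled bytecode: identical signature, no free variables
        return y + x

    original_code = add.__code__

    # install the replacement bytecode; calls to add() now execute add_optimized's code
    add.__code__ = add_optimized.__code__
    assert add(1, 2) == 3

    # restore the original code object, mirroring what dispatch_to_code() does on exit
    add.__code__ = original_code
    assert add(1, 2) == 3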