[Compile] Conditional compilation. Introduce compile_ranges #24252
base: main
Changes from 148 commits
New file, `@@ -0,0 +1,168 @@`:

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import torch
from torch import fx as fx
from torch import nn

# This import automatically registers `torch.ops.silly.attention`
import tests.compile.silly_attention  # noqa
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.inductor_pass import (
    InductorPass,
    get_pass_context,
)
from vllm.config import (
    VllmConfig,
    set_current_vllm_config,
)
from vllm.config.compilation import CompilationConfig, CompilationMode
from vllm.config.scheduler import SchedulerConfig
from vllm.config.utils import Range
from vllm.forward_context import set_forward_context

BATCH_SIZE = 64
MLP_SIZE = 128


@support_torch_compile
class TestModel(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + x
        attn_output = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, attn_output)
        x = attn_output
        x = x * 3
        return x


@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]):
    with set_forward_context({}, vllm_config=vllm_config):
        model(torch.randn(BATCH_SIZE, MLP_SIZE))
        for batch_size in batch_sizes:
            model(torch.randn(batch_size, MLP_SIZE))


class PostGradRangeChecker(InductorPass):
    def __init__(self, ranges: list[Range]):
        self.ranges = ranges
        self.num_calls = 0

    def __call__(self, graph: fx.Graph):
        compile_range = get_pass_context().compile_range
        assert compile_range in self.ranges, (
            f"Compile range {compile_range} not in {self.ranges}"
        )
        self.num_calls += 1

    def uuid(self) -> str:
        state: dict[str, Any] = {}
        return InductorPass.hash_dict(state)
```
```python
def test_compile_ranges(use_fresh_inductor_cache):
    post_grad_range_checker = PostGradRangeChecker(
        [
            Range(start=1, end=8),
            Range(start=16, end=16),
            Range(start=9, end=32),
            Range(start=64, end=64),
            Range(start=33, end=8192),
        ]
    )
    torch.set_default_device("cuda")
    vllm_config = VllmConfig(
        scheduler_config=SchedulerConfig(
            max_num_batched_tokens=8192,
        ),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            compile_ranges_split_points=[8, 32],
            compile_sizes=[16, 64, 128],
            inductor_compile_config={
                "post_grad_custom_post_pass": post_grad_range_checker,
            },
        ),
    )

    with set_current_vllm_config(vllm_config):
        model = TestModel(vllm_config=vllm_config, prefix="").eval()
        # Number of compilations: 3 compile ranges + 2 compile sizes
        batch_sizes = [1, 4, 16, 24, 48, 64, 8192]

        with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=1,
            num_backend_compilations=5,
        ):
            run_model(vllm_config, model, batch_sizes)
        assert post_grad_range_checker.num_calls == 5
```

Review discussion on test_compile_ranges:

Collaborator: How come this works without disabling the vllm cache?
Author: Probably the clean inductor cache allows us to avoid cache hits of the vllm cache.
Collaborator: Hmm, ok.

On `compile_sizes=[16, 64, 128]`:

Contributor: nit: I wonder if we should now call those "specialize sizes"?
Collaborator: Yeah, we can do that in a follow-up.
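The `use_fresh_inductor_cache` fixture referenced in the first thread above is not defined in this diff. A minimal sketch of what such a fixture might look like, assuming PyTorch's `fresh_inductor_cache` helper (an assumption; the actual fixture may differ):

```python
# Hypothetical sketch of a `use_fresh_inductor_cache` fixture; the real fixture
# lives elsewhere in the test suite and may be implemented differently.
import pytest
from torch._inductor.utils import fresh_inductor_cache


@pytest.fixture
def use_fresh_inductor_cache():
    # Redirect the Inductor cache to a temporary directory for this test,
    # so compiled artifacts from earlier runs cannot be reused.
    with fresh_inductor_cache():
        yield
```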
```python
def test_compile_config_get_compile_ranges():
    compilation_config = CompilationConfig(
        compile_ranges_split_points=[8, 32],
    )
    VllmConfig(
        scheduler_config=SchedulerConfig(
            max_num_batched_tokens=8192,
        ),
        compilation_config=compilation_config,
    )
    assert compilation_config.get_compile_ranges() == [
        Range(start=1, end=8),
        Range(start=9, end=32),
        Range(start=33, end=8192),
    ]
```
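As the assertion above shows, split points [8, 32] with `max_num_batched_tokens=8192` partition the batch-size axis into [1, 8], [9, 32], and [33, 8192]. A tiny standalone helper that mirrors this mapping (an illustration of the semantics the test asserts, not vLLM's actual `get_compile_ranges()` implementation):

```python
# Illustration only: mirrors the split-point semantics asserted in
# test_compile_config_get_compile_ranges, not the real implementation.
def split_points_to_ranges(
    split_points: list[int], max_num_batched_tokens: int
) -> list[tuple[int, int]]:
    starts = [1] + [p + 1 for p in split_points]
    ends = [*split_points, max_num_batched_tokens]
    return list(zip(starts, ends))


assert split_points_to_ranges([8, 32], 8192) == [(1, 8), (9, 32), (33, 8192)]
```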
```python
def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
    # To force multiple compilations, we disable the compile cache
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    post_grad_range_checker = PostGradRangeChecker(
        ranges=[
            Range(start=1, end=8),
            Range(start=9, end=8192),
        ]
    )
    scheduler_config = SchedulerConfig(
        max_num_batched_tokens=8192,
    )
    torch.set_default_device("cuda")

    def create_vllm_config():
        return VllmConfig(
            scheduler_config=scheduler_config,
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                compile_ranges_split_points=[8],
                inductor_compile_config={
                    "post_grad_custom_post_pass": post_grad_range_checker,
                },
            ),
        )

    vllm_config_1 = create_vllm_config()
    with set_current_vllm_config(vllm_config_1):
        model1 = TestModel(vllm_config=vllm_config_1, prefix="").eval()
        batch_sizes = [1, 16]
        run_model(vllm_config_1, model1, batch_sizes)
        assert post_grad_range_checker.num_calls == 2

    post_grad_range_checker.num_calls = 0
    # Create a new vllm config with the new pass context
    vllm_config_2 = create_vllm_config()
    with set_current_vllm_config(vllm_config_2):
        model2 = TestModel(vllm_config=vllm_config_2, prefix="").eval()
        batch_sizes = [4, 32]
        run_model(vllm_config_2, model2, batch_sizes)
        # Check that the Inductor cache is reused, so the number of calls
        # should be 0
        assert post_grad_range_checker.num_calls == 0
```
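Beyond assertion-only checkers like `PostGradRangeChecker`, a custom pass could also branch its behavior on the current compile range. A minimal sketch, reusing the same `get_pass_context()` API exercised by the tests above; the small-batch-only condition is invented purely for illustration:

```python
# Hypothetical sketch: a pass that skips its rewrites outside small-batch
# compile ranges, using the same pass-context API as PostGradRangeChecker.
from torch import fx

from vllm.compilation.inductor_pass import InductorPass, get_pass_context


class SmallBatchOnlyPass(InductorPass):
    def __call__(self, graph: fx.Graph) -> None:
        compile_range = get_pass_context().compile_range
        # Only act when compiling for a range that fits entirely within 32
        # tokens; compile_range may be absent outside range-based compilation.
        if compile_range is None or compile_range.end > 32:
            return
        # ... range-specific graph rewrites would go here ...

    def uuid(self) -> str:
        # Include anything that affects the pass output in the cache key.
        return InductorPass.hash_dict({"name": "small_batch_only"})
```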