-
-
Notifications
You must be signed in to change notification settings - Fork 11.7k
[Compile] Conditional compilation. Introduce compile_ranges #24252
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ilmarkov
wants to merge
164
commits into
vllm-project:main
Choose a base branch
from
neuralmagic:imarkov/conditional_compilation_ranges
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 128 commits
Commits
Show all changes
164 commits
Select commit
Hold shift + click to select a range
21d7d67
Functionalized patterns in prep for utility
ProExpertProg f3b4cf1
TEMP Mostly working
ProExpertProg cdad3c0
TEMP: fixed rmsnorm issue (TODO assert dtypes in fused norm_quant ker…
ProExpertProg 8e4a56f
rms works fully now, had to remove more conversions (and add them in …
ProExpertProg e151e6d
quant works except (torch,torch)
ProExpertProg 14fdc8b
quant with fix for pure torch, broke others
ProExpertProg 05a65f3
ALL WORKS
ProExpertProg e6b394e
Add TODO
ProExpertProg d96913a
Cleanup test_fusion.py, added extra layer of rms/quant
ProExpertProg b172747
Functionalize attn+quant patterns
ProExpertProg 1ae80c6
Move global vllm_config to pass manager
ProExpertProg 77835fd
Attention fusion works with custom ops
ProExpertProg 1277999
Remove V0 attn fusion test
ProExpertProg d843a67
Add triton attn test to attn+quant fusion
ProExpertProg cdd1529
Flat product for better test names/visibility
ProExpertProg 141a37e
Fix rmsnorm
ProExpertProg c6d6c3b
Refactor E2E attn fusion test
ProExpertProg 490ac86
Add TP=2 test (untested)
ProExpertProg d0b1b56
improve tests by adding more cases
ProExpertProg 47b4688
TEMP working on caplog
ProExpertProg ae7f56f
Temp MP workaround P2
ProExpertProg eb899a4
Temp MP workaround P3
ProExpertProg a2aa978
Test for caplog utils
ProExpertProg 21a9f9f
Fixed tests, passing with 2.8, 2.9 tbd
ProExpertProg 66a35a9
Update tests/compile/backend.py
ProExpertProg 7eb1364
Update csrc/layernorm_kernels.cu
ProExpertProg 5fef180
clean up fullgraph tests
ProExpertProg db479ae
TEMP allreduce fusion
ProExpertProg 54189a9
allreduce fusion working (custom ops on)
ProExpertProg b7f52bf
allreduce fusion working with/without custom ops (except fp4)
ProExpertProg d09a278
allreduce fusion working with/without custom ops (with fp4)
ProExpertProg c8675ff
log depyf folder, fix context for TestBackend, fix pattern dump
ProExpertProg d3f95fe
fullgraph allreduce test update requirements
ProExpertProg 4dbfcf7
Move e2e tests to new file, add to test pipeline
ProExpertProg 31d0127
Add e2e fusions to fullgraph test (should work with Triton backend), …
ProExpertProg c653d24
Fix spelling, precommit
ProExpertProg 1756f67
add back fp4
ProExpertProg 5619bc3
clean up e2e tests
ProExpertProg 32989d8
add pattern for final allreduce in model
ProExpertProg 46ee626
add more comprehensive testing for quantfp8 (-rmsnorm+-quant still fa…
ProExpertProg a1c7fdb
add more comprehensive testing for allreduce-rmsnorm, fix fp4 (-rmsno…
ProExpertProg c3264d8
Fix partial match rmsnorm+quant, fix allreduce+rmsnorm match
ProExpertProg 095277c
Simplify matcher utils by using RMSNorm.forward_static
ProExpertProg 52f78ce
Add allreduce test to 2-gpu test
ProExpertProg 1b1a63e
Fix e2e allreduce fusion test
ProExpertProg 0d6e550
fix func test
ProExpertProg 26892df
fix pass manager test
ProExpertProg 3547b87
fix sequence parallelism test
ProExpertProg af1ffa7
PR review
ProExpertProg 97b3ff2
Merge remote-tracking branch 'upstream/main' into luka/custom-op-matc…
ProExpertProg b5f89e5
Cleanup test_full_graph.py
ProExpertProg f6429e4
Cleanup test_fusion_attn.py
ProExpertProg 8a363d3
Slight improvement for E2E fusion
ProExpertProg 12a7c6d
Tests & docs for flat_product
ProExpertProg db16ee1
Merge branch 'main' into luka/custom-op-matching-2
ProExpertProg 8ffb474
Remove/fix TODOs
ProExpertProg 2a6299c
Fix e2e test patterns
ProExpertProg 465ce58
Update tests/compile/test_fusion.py
ProExpertProg bb0254a
Merge branch 'main' into luka/custom-op-matching-2
ProExpertProg bcd95b5
Fix func test
ProExpertProg db2b1c7
Smaller model for e2e fusion test
ProExpertProg a3ebf0a
fix fp8 quant tests
ProExpertProg 3943257
Restore original torch.Parameter behavior in RMSNorm
ProExpertProg 532cbcf
Add comment to test_logger
ProExpertProg 7e6f5b3
add flat_product example
ProExpertProg 24f1298
PR comments: cleanup fusion passes, & matching
ProExpertProg de7405b
PR comments: add _custom_op suffix
ProExpertProg 6253d5b
Add e2e to L40 distributed, move tests to start of B200 distributed
ProExpertProg 876ef22
Fix tests, PR feedback
ProExpertProg e99a759
Break up B200 tests, move allreduce to H200
ProExpertProg a226864
Merge branch 'main' into luka/custom-op-matching-2
ProExpertProg ae581e1
Fix attention fusion test numerics
ProExpertProg c03b29b
Remove inductor graph partition from unit test (included in e2e tests)
ProExpertProg d2e0489
Relax tolerance for L40 fusion test
ProExpertProg 65ef5fd
Merge branch 'main' into luka/custom-op-matching-2
ProExpertProg d4fe977
Fix NamedTuple
ProExpertProg 6319e39
Update test durations
ProExpertProg e34d36d
More tweaking of precision
ProExpertProg f72ee43
Split original pr
ilmarkov c4c0215
Update bench
ilmarkov 309d79e
Update threshold configuration
ilmarkov afcfd73
Move all_reduce from custom op in fused_moe
ilmarkov 0248dcd
Linter fixes
ilmarkov 18e4771
Upd
ilmarkov 1debd8e
Merge branch 'main' into imarkov/fused_allreduce_torch_native
ilmarkov 9516d2b
Upd after review
ilmarkov b789044
Update fused_moe
ilmarkov 4001935
Merge branch 'main' into imarkov/fused_allreduce_torch_native
ilmarkov 6077616
Address comments
ilmarkov afc8af8
Remove bench_compile
ilmarkov c3af2af
Split PR. Second part. Compile ranges
ilmarkov 0cbb065
Remove general shape graph
ilmarkov d5392f5
Add test to test pipeline
ilmarkov 027c9eb
Fix pre-commit
ilmarkov b2992d3
Upd
ilmarkov 3499384
Upd config
ilmarkov 5336ee6
Fix
ilmarkov 4958474
Priotitize compile_sizes
ilmarkov 04306ed
Fix inductor config
ilmarkov 9dc4eea
Laith's fix
ilmarkov 2c63f0b
Upd
ilmarkov 67f7ae1
Update config
ilmarkov 8b8d01d
Merge branch 'imarkov/fused_allreduce_torch_native' into imarkov/cond…
ilmarkov fcebc21
Add caching
ilmarkov 65151bc
Address comments
ilmarkov 1f7afdb
Add debug log
ilmarkov 8da1585
Merge branch 'main' into imarkov/fused_allreduce_torch_native
ilmarkov df22202
Update benchmark
ilmarkov a21de2b
Fix
ilmarkov 45f4093
Update bench and constants
ilmarkov c26e056
Rename in benchmark
ilmarkov 1bee5a6
Merge branch 'main' into imarkov/fused_allreduce_torch_native
ilmarkov bcc0cc0
Add max_token_num to object
ilmarkov 43b163c
Add test
ilmarkov 71c6b72
Update comments
ilmarkov ada24e6
Merge branch 'imarkov/fused_allreduce_torch_native' into imarkov/cond…
ilmarkov 6766e4f
Update fakify for compile sizes
ilmarkov af87d7a
Linter fix
ilmarkov 56273da
Merge branch 'main' into imarkov/fused_allreduce_torch_native
ilmarkov 459f71c
Merge branch 'imarkov/fused_allreduce_torch_native' into imarkov/cond…
ilmarkov 2785e4d
Minor updates
ilmarkov 1f83a66
Merge branch 'main' into imarkov/fused_allreduce_torch_native
ProExpertProg b4c1b1d
Address the review
ilmarkov ab33605
Merge branch 'main' into imarkov/fused_allreduce_torch_native
robertgshaw2-redhat 3fac39b
Merge branch 'main' into imarkov/fused_allreduce_torch_native
ilmarkov b0a3884
Fix SP
ilmarkov a3e7bdc
Merge branch 'imarkov/fused_allreduce_torch_native' into imarkov/cond…
ilmarkov a810969
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov 319abd5
Remove dynamic shape
ilmarkov d168de0
Make ranges inclusive-inclusive
ilmarkov b65e752
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov af10400
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov 6c05919
Add test for inductor cache hits
ilmarkov 03637e7
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov 3f72483
Address comments
ilmarkov 9b00ebc
Address comments
ilmarkov 8a40ac6
Update test
ilmarkov ef05682
Address comments
ilmarkov 63af962
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov 7647089
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov ee89388
Update test utils
ilmarkov 925e87d
Fix pre-commit after merge
ilmarkov 809e170
Fix tests
ilmarkov e07c939
Add fixture instead of decorator
ilmarkov f4db45c
Fix re-used compilation config
ilmarkov 97a8d58
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov 4f280ce
Fix e2e
ilmarkov f714957
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov b27f89d
Fix e2e adapt to number of compile ranges
ilmarkov eedc70e
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov cc8f2f8
Slight fix of test
ilmarkov d1dd4db
Fix tests after refactor
ilmarkov a2b67a4
Simplify
ilmarkov 0776364
Address comments
ilmarkov 42bf355
Merge branch 'main' into imarkov/conditional_compilation_ranges
ilmarkov ca832fc
Merge remote-tracking branch 'upstream/main' into imarkov/conditional…
ProExpertProg ba90b9e
Only warm up model if mode=VLLM_COMPILE
ProExpertProg 771203f
Fix capture-sizes
ProExpertProg 0e0eab9
Fix doc range
ProExpertProg 3d2c36b
pre-commit
ProExpertProg 18ff16e
Fix types for precommit
ProExpertProg 6bc8258
Update vllm/v1/worker/gpu_worker.py
ProExpertProg c43458b
Merge remote-tracking branch 'upstream/main' into imarkov/conditional…
ProExpertProg f4c0ae7
Check that the pass was skipped in other range
ProExpertProg File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| import torch | ||
| from torch import fx as fx | ||
| from torch import nn | ||
|
|
||
| # This import automatically registers `torch.ops.silly.attention` | ||
| import tests.compile.silly_attention # noqa | ||
| from vllm.compilation.counter import compilation_counter | ||
| from vllm.compilation.decorators import support_torch_compile | ||
| from vllm.compilation.inductor_pass import ( | ||
| InductorPass, | ||
| get_pass_context, | ||
| ) | ||
| from vllm.config import ( | ||
| VllmConfig, | ||
| set_current_vllm_config, | ||
| ) | ||
| from vllm.config.compilation import CompilationConfig, CompilationMode | ||
| from vllm.config.scheduler import SchedulerConfig | ||
| from vllm.config.utils import Range | ||
| from vllm.forward_context import set_forward_context | ||
|
|
||
| BATCH_SIZE = 64 | ||
| MLP_SIZE = 128 | ||
|
|
||
|
|
||
| @support_torch_compile | ||
| class TestModel(nn.Module): | ||
| def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: | ||
| super().__init__() | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| x = x + x | ||
| attn_output = torch.empty_like(x) | ||
| torch.ops.silly.attention(x, x, x, attn_output) | ||
| x = attn_output | ||
| x = x * 3 | ||
| return x | ||
|
|
||
|
|
||
| @torch.inference_mode | ||
| def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]): | ||
| with set_forward_context({}, vllm_config=vllm_config): | ||
| model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) | ||
| for batch_size in batch_sizes: | ||
| model(torch.randn(batch_size, MLP_SIZE).cuda()) | ||
ilmarkov marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| class PostGradPassManagerCheckRanges(InductorPass): | ||
ilmarkov marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| def __init__(self, ranges: list[Range]): | ||
| self.ranges = ranges | ||
| self.num_calls = 0 | ||
|
|
||
| def __call__(self, graph: fx.Graph): | ||
| compile_range = get_pass_context().compile_range | ||
| assert compile_range in self.ranges, ( | ||
| f"Compile range {compile_range} not in {self.ranges}" | ||
| ) | ||
| self.num_calls += 1 | ||
ProExpertProg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def uuid(self) -> str: | ||
| state = { | ||
| "ranges": [str(range) for range in self.ranges], | ||
| "current_compile_range": str(get_pass_context().compile_range), | ||
| } | ||
| return InductorPass.hash_dict(state) | ||
ilmarkov marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| def test_compile_ranges(): | ||
| post_grad_pass_manager = PostGradPassManagerCheckRanges( | ||
| [ | ||
| Range(start=1, end=8), | ||
| Range(start=8, end=32), | ||
| Range(start=32, end=8193), | ||
ilmarkov marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ] | ||
| ) | ||
| vllm_config = VllmConfig( | ||
| scheduler_config=SchedulerConfig( | ||
| max_num_batched_tokens=8192, | ||
| ), | ||
| compilation_config=CompilationConfig( | ||
| mode=CompilationMode.VLLM_COMPILE, | ||
| compile_ranges_split_points=[8, 32], | ||
| inductor_compile_config={ | ||
| "post_grad_custom_post_pass": post_grad_pass_manager, | ||
| # Disable inductor cache to get the number of passes correctly | ||
| "force_disable_caches": True, | ||
ilmarkov marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| }, | ||
| ), | ||
| ) | ||
|
|
||
| with set_current_vllm_config(vllm_config): | ||
ProExpertProg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| model = TestModel(vllm_config=vllm_config, prefix="").eval().cuda() | ||
| batch_sizes = [1, 4, 16, 24, 48, 64] | ||
ilmarkov marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # A has support_torch_compile | ||
ProExpertProg marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| with compilation_counter.expect( | ||
| num_graphs_seen=1, | ||
| num_piecewise_graphs_seen=1, | ||
| num_backend_compilations=3, | ||
| # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen | ||
ilmarkov marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ): | ||
| run_model(vllm_config, model, batch_sizes) | ||
| assert post_grad_pass_manager.num_calls == 3 | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.