[MRV2] Extensible CG dispatch rework #35959
Merged: WoosukKwon merged 35 commits into vllm-project:main from neuralmagic:lwilkinson/mrv2-cg-dispatch on Mar 9, 2026.
Commits (35, all by LucasWilkinson):
- d5f114c wip
- 81e3ac3 wip
- 73a36ff cleanup
- 080ee62 wip
- 4f3ed1a clean
- 29af75c cleanup
- f4626a5 cleanup
- 3e142a0 Refactor cudagraph capture: base class handles warmup and capture mec…
- 18f3e24 cleanup
- 76b7d6f Simplify Eagle cudagraph: FULL mode only, align with main
- ac055f2 Make num_reqs_padded required in gather_block_tables
- 3ad20f7 cleanup
- 5d466b6 cleanup
- 1dc0457 cleanup
- db65896 cleanup
- c3419a4 cleanup
- 4d85377 cleanup
- d3c93ed cleanup
- 91aa819 cleanup
- 43bb5ca cleanup
- 6e6c062 cleanup
- 465ca0d cleanup
- 64db442 cleanup
- d0f8bfe cleanup
- 6e8b57e cleanup
- 1141328 fix
- d9fe9a5 review comments + fix PIECEWISE only
- ad1834c fix deserialize
- fde7e34 cleanup DP and don't request pad for PIECEWISE
- 0198457 cleanup
- 54b9c79 fix hang
- 03b5789 fix
- 82fbc2d cleanup
- aba1411 cleanup
- f424970 fix(eagle): use dedicated cudagraph pool to prevent memory overlap
```diff
@@ -104,19 +104,24 @@ def apply_staged_writes(self) -> None:
         self.num_blocks.copy_to_uva()
 
     def gather_block_tables(
-        self, idx_mapping: torch.Tensor
+        self,
+        idx_mapping: torch.Tensor,
+        num_reqs_padded: int,
     ) -> tuple[torch.Tensor, ...]:
         num_reqs = idx_mapping.shape[0]
-        _gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
+        # Launch kernel with num_reqs_padded to fuse zeroing of padded rows.
+        _gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs_padded)](
             idx_mapping,
             self.block_table_ptrs,
             self.input_block_table_ptrs,
             self.block_table_strides,
             self.num_blocks.gpu,
             self.num_blocks.gpu.stride(0),
+            num_reqs,
+            self.input_block_tables[0].shape[1],  # max_num_blocks
             BLOCK_SIZE=1024,  # type: ignore
         )
-        return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
+        return tuple(bt[:num_reqs_padded] for bt in self.input_block_tables)
 
     def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
         # NOTE(woosuk): The output may be used for CUDA graph capture.
```

Collaborator (on the padded kernel launch): QQ: Why do we need to zero out the block table at all?

Collaborator (on the `max_num_blocks` argument): Is this safe? IIUC, block tables of different KV cache groups might have different sizes.
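The hunk above changes `gather_block_tables` to take an explicit `num_reqs_padded` and return slices of that fixed length, with rows past the real `num_reqs` zeroed in the same kernel launch. A minimal pure-Python sketch of that semantics (the helper and its list-of-lists "tensors" are illustrative stand-ins, not the actual vLLM implementation):

```python
# Hypothetical sketch of the padded-gather semantics, assuming plain
# Python lists in place of torch tensors.

def gather_block_tables(idx_mapping, src_block_table, dst_block_table,
                        num_reqs_padded):
    """Copy src[idx_mapping[i]] -> dst[i] for real requests, and zero
    dst rows in [num_reqs, num_reqs_padded) in the same pass."""
    num_reqs = len(idx_mapping)
    assert num_reqs_padded >= num_reqs
    max_num_blocks = len(dst_block_table[0])
    for batch_idx in range(num_reqs_padded):
        if batch_idx >= num_reqs:
            # Padded row: zeroed, mirroring the fused kernel branch.
            dst_block_table[batch_idx] = [0] * max_num_blocks
        else:
            req_idx = idx_mapping[batch_idx]
            dst_block_table[batch_idx] = list(src_block_table[req_idx])
    # Return a fixed-length view, as the new code slices to num_reqs_padded.
    return dst_block_table[:num_reqs_padded]
```

Returning a `num_reqs_padded`-length view (rather than `num_reqs`) keeps the output shape static across batches of different real sizes, which is what CUDA graph capture and replay require.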
```diff
@@ -130,6 +135,7 @@ def compute_slot_mappings(
         idx_mapping: torch.Tensor,
         query_start_loc: torch.Tensor,
         positions: torch.Tensor,
+        num_tokens_padded: int,
     ) -> torch.Tensor:
         num_reqs = idx_mapping.shape[0]
         num_tokens = positions.shape[0]
@@ -151,7 +157,7 @@ def compute_slot_mappings(
             PAD_ID=PAD_SLOT_ID,
             TRITON_BLOCK_SIZE=1024,  # type: ignore
         )
-        return self.slot_mappings[:, :num_tokens]
+        return self.slot_mappings[:, :num_tokens_padded]
 
     def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
         # Fill the entire slot_mappings tensor, not just the first `num_tokens` entries.
```
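Analogously to the block tables, the slot mappings are now sliced to `num_tokens_padded`, with the tail filled by the `PAD_SLOT_ID` sentinel that the kernel already passes as `PAD_ID`. A small sketch of the resulting shape contract (the helper name, list representation, and sentinel value `-1` are assumptions for illustration):

```python
# Illustrative sketch: the returned mapping has a fixed padded length,
# with real slot ids first and a sentinel in the padded tail.
PAD_SLOT_ID = -1  # hypothetical value; the real constant lives in vLLM

def compute_slot_mappings(slot_ids, num_tokens_padded):
    """Return a num_tokens_padded-length mapping: real slots, then padding."""
    assert num_tokens_padded >= len(slot_ids)
    return slot_ids + [PAD_SLOT_ID] * (num_tokens_padded - len(slot_ids))
```

Because the padded entries carry a sentinel rather than stale slot ids, downstream KV-cache writes can mask them out while the tensor shape stays constant for graph replay.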
```diff
@@ -173,21 +179,31 @@ def _gather_block_tables_kernel(
     block_table_strides,  # [num_kv_cache_groups]
     num_blocks_ptr,  # [num_kv_cache_groups, max_num_reqs]
     num_blocks_stride,
+    num_reqs,  # actual number of requests (for padding)
+    max_num_blocks,  # stride for zeroing padded rows
     BLOCK_SIZE: tl.constexpr,
 ):
     # kv cache group id
     group_id = tl.program_id(0)
     batch_idx = tl.program_id(1)
-    req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
 
+    stride = tl.load(block_table_strides + group_id)
+    dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
+    dst_row_ptr = dst_block_table_ptr + batch_idx * stride
+
+    if batch_idx >= num_reqs:
+        # Zero out padded rows.
+        for i in tl.range(0, max_num_blocks, BLOCK_SIZE):
+            offset = i + tl.arange(0, BLOCK_SIZE)
+            tl.store(dst_row_ptr + offset, 0, mask=offset < max_num_blocks)
+        return
+
+    req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
     group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
     num_blocks = tl.load(group_num_blocks_ptr + req_idx)
 
-    stride = tl.load(block_table_strides + group_id)
     src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
     src_row_ptr = src_block_table_ptr + req_idx * stride
-    dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
-    dst_row_ptr = dst_block_table_ptr + batch_idx * stride
 
     for i in tl.range(0, num_blocks, BLOCK_SIZE):
         offset = i + tl.arange(0, BLOCK_SIZE)
```
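The kernel launches a 2D grid of (KV-cache group, padded batch index); programs past the real `num_reqs` only zero their destination row, while the rest gather the source row. A hypothetical pure-Python emulation of that grid's control flow (pointer arithmetic replaced by list indexing, names illustrative):

```python
# Hypothetical emulation of the Triton grid above: one "program" per
# (kv-cache group, padded batch index). Shows the branch structure,
# not the real pointer arithmetic or masked vector stores.

def run_gather_grid(num_kv_cache_groups, num_reqs_padded, num_reqs,
                    idx_mapping, src_tables, dst_tables, num_blocks):
    for group_id in range(num_kv_cache_groups):
        for batch_idx in range(num_reqs_padded):
            dst_row = dst_tables[group_id][batch_idx]
            if batch_idx >= num_reqs:
                # Padded program: zero the whole row and exit early.
                for j in range(len(dst_row)):
                    dst_row[j] = 0
                continue
            req_idx = idx_mapping[batch_idx]
            # Copy only the first num_blocks entries of the source row;
            # entries beyond that are left untouched, as in the kernel.
            n = num_blocks[group_id][req_idx]
            src_row = src_tables[group_id][req_idx]
            for j in range(n):
                dst_row[j] = src_row[j]
```

Note the early-exit branch makes the per-program work data-independent of the request contents for padded slots, which is what lets the zeroing be fused into the same launch instead of a separate `fill_` pass.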
Collaborator: I think this implicit conversion can be confusing and error prone.