Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
d5f114c
wip
LucasWilkinson Mar 4, 2026
81e3ac3
wip
LucasWilkinson Mar 4, 2026
73a36ff
cleanup
LucasWilkinson Mar 4, 2026
080ee62
wip
LucasWilkinson Mar 4, 2026
4f3ed1a
clean
LucasWilkinson Mar 4, 2026
29af75c
cleanup
LucasWilkinson Mar 4, 2026
f4626a5
cleanup
LucasWilkinson Mar 4, 2026
3e142a0
Refactor cudagraph capture: base class handles warmup and capture mec…
LucasWilkinson Mar 4, 2026
18f3e24
cleanup
LucasWilkinson Mar 5, 2026
76b7d6f
Simplify Eagle cudagraph: FULL mode only, align with main
LucasWilkinson Mar 5, 2026
ac055f2
Make num_reqs_padded required in gather_block_tables
LucasWilkinson Mar 5, 2026
3ad20f7
cleanup
LucasWilkinson Mar 5, 2026
5d466b6
cleanup
LucasWilkinson Mar 5, 2026
1dc0457
cleanup
LucasWilkinson Mar 5, 2026
db65896
cleanup
LucasWilkinson Mar 5, 2026
c3419a4
cleanup
LucasWilkinson Mar 5, 2026
4d85377
cleanup
LucasWilkinson Mar 5, 2026
d3c93ed
cleanup
LucasWilkinson Mar 5, 2026
91aa819
cleanup
LucasWilkinson Mar 5, 2026
43bb5ca
cleanup
LucasWilkinson Mar 5, 2026
6e6c062
cleanup
LucasWilkinson Mar 5, 2026
465ca0d
cleanup
LucasWilkinson Mar 5, 2026
64db442
cleanup
LucasWilkinson Mar 5, 2026
d0f8bfe
cleanup
LucasWilkinson Mar 5, 2026
6e8b57e
cleanup
LucasWilkinson Mar 5, 2026
1141328
fix
LucasWilkinson Mar 5, 2026
d9fe9a5
review comments + fix PIECEWISE only
LucasWilkinson Mar 6, 2026
ad1834c
fix deserialize
LucasWilkinson Mar 6, 2026
fde7e34
cleanup DP and don't request pad for PIECEWISE
LucasWilkinson Mar 6, 2026
0198457
cleanup
LucasWilkinson Mar 6, 2026
54b9c79
fix hang
LucasWilkinson Mar 6, 2026
03b5789
fix
LucasWilkinson Mar 6, 2026
82fbc2d
cleanup
LucasWilkinson Mar 6, 2026
aba1411
cleanup
LucasWilkinson Mar 6, 2026
f424970
fix(eagle): use dedicated cudagraph pool to prevent memory overlap
LucasWilkinson Mar 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions vllm/config/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def is_valid_runtime_mode(self) -> bool:
def __str__(self) -> str:
    # Render the enum member by its bare name (e.g. "FULL", "PIECEWISE")
    # instead of the default "ClassName.MEMBER" repr-style string.
    return self.name

def __bool__(self) -> bool:
    # Truthiness shorthand: every mode except NONE is truthy, enabling
    # `if cudagraph_mode:` at call sites. NOTE(review): this implicit
    # conversion was flagged as potentially confusing/error-prone —
    # prefer an explicit `mode != CUDAGraphMode.NONE` comparison in new code.
    return self != CUDAGraphMode.NONE
Comment on lines +100 to +101
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this implicit conversion can be confusing and error prone.



@config
class PassConfig:
Expand Down
32 changes: 24 additions & 8 deletions vllm/v1/worker/gpu/block_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,24 @@ def apply_staged_writes(self) -> None:
self.num_blocks.copy_to_uva()

def gather_block_tables(
self, idx_mapping: torch.Tensor
self,
idx_mapping: torch.Tensor,
num_reqs_padded: int,
) -> tuple[torch.Tensor, ...]:
num_reqs = idx_mapping.shape[0]
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
# Launch kernel with num_reqs_padded to fuse zeroing of padded rows.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: Why do we need to zero out the block table at all?

_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs_padded)](
idx_mapping,
self.block_table_ptrs,
self.input_block_table_ptrs,
self.block_table_strides,
self.num_blocks.gpu,
self.num_blocks.gpu.stride(0),
num_reqs,
self.input_block_tables[0].shape[1], # max_num_blocks
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this safe? IIUC, block tables of different KV cache groups might have different sizes.

BLOCK_SIZE=1024, # type: ignore
)
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
return tuple(bt[:num_reqs_padded] for bt in self.input_block_tables)

def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
# NOTE(woosuk): The output may be used for CUDA graph capture.
Expand All @@ -130,6 +135,7 @@ def compute_slot_mappings(
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
num_tokens_padded: int,
) -> torch.Tensor:
num_reqs = idx_mapping.shape[0]
num_tokens = positions.shape[0]
Expand All @@ -151,7 +157,7 @@ def compute_slot_mappings(
PAD_ID=PAD_SLOT_ID,
TRITON_BLOCK_SIZE=1024, # type: ignore
)
return self.slot_mappings[:, :num_tokens]
return self.slot_mappings[:, :num_tokens_padded]

def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
# Fill the entire slot_mappings tensor, not just the first `num_tokens` entries.
Expand All @@ -173,21 +179,31 @@ def _gather_block_tables_kernel(
block_table_strides, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
num_reqs, # actual number of requests (for padding)
max_num_blocks, # stride for zeroing padded rows
BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)

stride = tl.load(block_table_strides + group_id)
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride

if batch_idx >= num_reqs:
# Zero out padded rows.
for i in tl.range(0, max_num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
tl.store(dst_row_ptr + offset, 0, mask=offset < max_num_blocks)
return

req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
num_blocks = tl.load(group_num_blocks_ptr + req_idx)

stride = tl.load(block_table_strides + group_id)
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
src_row_ptr = src_block_table_ptr + req_idx * stride
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride

for i in tl.range(0, num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
Expand Down
Loading