diff --git a/vllm/v1/worker/gpu/block_table.py b/vllm/v1/worker/gpu/block_table.py
index ee18401f6304..9261ff4da28d 100644
--- a/vllm/v1/worker/gpu/block_table.py
+++ b/vllm/v1/worker/gpu/block_table.py
@@ -6,8 +6,9 @@
 
 from vllm.triton_utils import tl, triton
 from vllm.utils.math_utils import cdiv
+from vllm.utils.platform_utils import is_uva_available
+from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
 from vllm.v1.attention.backends.utils import PAD_SLOT_ID
-from vllm.v1.utils import CpuGpuBuffer
 
 
 class BlockTables:
@@ -18,51 +19,53 @@ def __init__(
         max_num_batched_tokens: int,
         max_model_len: int,
         device: torch.device,
-        pin_memory: bool,
     ):
         self.block_sizes = block_sizes
         self.max_num_reqs = max_num_reqs
         self.max_num_batched_tokens = max_num_batched_tokens
         self.max_model_len = max_model_len
         self.device = device
-        self.pin_memory = pin_memory
+
+        if not is_uva_available():
+            raise RuntimeError("UVA is not available")
 
         self.num_kv_cache_groups = len(self.block_sizes)
         # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
-        self.block_tables: list[torch.Tensor] = []
+        self.block_tables: list[UvaBuffer] = []
         for i in range(self.num_kv_cache_groups):
             block_size = self.block_sizes[i]
             max_num_blocks = cdiv(self.max_model_len, block_size)
-            block_table = torch.zeros(
+            block_table = UvaBuffer(
                 self.max_num_reqs,
                 max_num_blocks,
                 dtype=torch.int32,
-                device=self.device,
             )
             self.block_tables.append(block_table)
-        self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
-
-        # Block tables used for model's forward pass.
-        # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
-        self.input_block_tables: list[torch.Tensor] = [
-            torch.zeros_like(block_table) for block_table in self.block_tables
-        ]
-        self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)
-
+        self.block_table_ptrs = self._make_ptr_tensor(
+            [b.gpu for b in self.block_tables]
+        )
         self.block_table_strides = torch.tensor(
-            [b.stride(0) for b in self.block_tables],
+            [b.gpu.stride(0) for b in self.block_tables],
             dtype=torch.int64,
             device=self.device,
         )
+
         self.block_sizes_tensor = torch.tensor(
             self.block_sizes, dtype=torch.int32, device=self.device
         )
-        self.num_blocks = torch.zeros(
+        self.num_blocks = UvaBuffer(
             self.num_kv_cache_groups,
             self.max_num_reqs,
             dtype=torch.int32,
-            device=self.device,
         )
+
+        # Block tables used for model's forward pass.
+        # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
+        self.input_block_tables: list[torch.Tensor] = [
+            torch.zeros_like(b.gpu) for b in self.block_tables
+        ]
+        self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)
+
         self.slot_mappings = torch.zeros(
             self.num_kv_cache_groups,
             self.max_num_batched_tokens,
@@ -70,74 +73,45 @@ def __init__(
             device=self.device,
         )
 
-        # Misc buffers.
-        self.req_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
-        self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool)
-        self.cu_num_new_blocks = self._make_buffer(
-            self.num_kv_cache_groups, self.max_num_reqs + 1, dtype=torch.int32
-        )
-
-    def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
-        return CpuGpuBuffer(
-            *args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
-        )
-
     def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
         # NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
         ptrs_tensor_cpu = torch.tensor(
             [t.data_ptr() for t in x],
             dtype=torch.uint64,
             device="cpu",
-            pin_memory=self.pin_memory,
+            pin_memory=True,
         )
         return ptrs_tensor_cpu.to(self.device, non_blocking=True)
 
     def append_block_ids(
         self,
-        # [num_reqs]
-        req_indices: list[int],
-        # [num_kv_cache_groups, num_reqs + 1]
-        cu_num_new_blocks: tuple[list[int], ...],
-        # [num_kv_cache_groups, num_new_blocks]
+        req_index: int,
         new_block_ids: tuple[list[int], ...],
-        # [num_reqs]
-        overwrite: list[bool],
+        overwrite: bool,
     ) -> None:
-        num_reqs = len(req_indices)
-        self.req_indices.np[:num_reqs] = req_indices
-        self.overwrite.np[:num_reqs] = overwrite
-        for i in range(self.num_kv_cache_groups):
-            self.cu_num_new_blocks.np[i, : num_reqs + 1] = cu_num_new_blocks[i]
-
-        # NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's
-        # no clear upper bound to the number of new blocks in a single step.
-        # NOTE(woosuk): The buffer has to be cached, because otherwise we cannot
-        # guarantee that the buffer is not freed before the copy is completed.
-        self.new_block_ids_cpu = torch.empty(
-            self.num_kv_cache_groups,
-            max(len(x) for x in new_block_ids),
-            dtype=torch.int32,
-            device="cpu",
-            pin_memory=self.pin_memory,
-        )
-        new_block_ids_np = self.new_block_ids_cpu.numpy()
+        """Write `new_block_ids` into row `req_index` of each group's table.
+
+        With `overwrite` the row restarts at column 0; otherwise the IDs are
+        appended after the count stored in `num_blocks`.
+        """
         for i in range(self.num_kv_cache_groups):
-            new_block_ids_np[i, : len(new_block_ids[i])] = new_block_ids[i]
-        new_block_ids_gpu = self.new_block_ids_cpu.to(self.device, non_blocking=True)
-
-        _append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)](
-            self.req_indices.copy_to_gpu(num_reqs),
-            self.cu_num_new_blocks.copy_to_gpu(),
-            self.cu_num_new_blocks.gpu.stride(0),
-            new_block_ids_gpu,
-            new_block_ids_gpu.stride(0),
-            self.overwrite.copy_to_gpu(num_reqs),
-            self.block_table_strides,
-            self.block_table_ptrs,
-            self.num_blocks,
-            self.num_blocks.stride(0),
-            BLOCK_SIZE=1024,  # type: ignore
-        )
+            block_ids = new_block_ids[i]
+            num_new_blocks = len(block_ids)
+            if num_new_blocks == 0:
+                if overwrite:
+                    # The removed Triton kernel stored a zero count here too;
+                    # without it the count previously stored at this index stays.
+                    self.num_blocks.np[i, req_index] = 0
+                continue
+
+            # TODO(woosuk): Too many Numpy invocations. Optimize this.
+            start = self.num_blocks.np[i, req_index] if not overwrite else 0
+            end = start + num_new_blocks
+            if num_new_blocks == 1:
+                self.block_tables[i].np[req_index, start] = block_ids[0]
+            else:
+                self.block_tables[i].np[req_index, start:end] = block_ids
+            self.num_blocks.np[i, req_index] = end
 
     def gather_block_tables(
         self,
@@ -149,8 +123,8 @@ def gather_block_tables(
             self.block_table_ptrs,
             self.input_block_table_ptrs,
             self.block_table_strides,
-            self.num_blocks,
-            self.num_blocks.stride(0),
+            self.num_blocks.gpu,
+            self.num_blocks.gpu.stride(0),
             BLOCK_SIZE=1024,  # type: ignore
         )
         return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
@@ -186,54 +160,6 @@ def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
         return self.slot_mappings[:, :num_tokens]
 
 
-@triton.jit
-def _append_block_ids_kernel(
-    # Inputs
-    req_indices,  # [num_reqs]
-    cu_num_new_blocks_ptr,  # [num_kv_cache_groups, num_reqs + 1]
-    cu_num_new_blocks_stride,
-    new_block_ids_ptr,  # [num_kv_cache_groups, num_new_blocks]
-    new_block_ids_stride,
-    overwrite,  # [num_reqs]
-    block_table_strides,  # [num_kv_cache_groups]
-    # Outputs
-    block_table_ptrs,  # [num_kv_cache_groups]
-    num_blocks_ptr,  # [num_kv_cache_groups, max_num_reqs]
-    num_blocks_stride,
-    # Constants
-    BLOCK_SIZE: tl.constexpr,
-):
-    group_id = tl.program_id(0)
-    batch_idx = tl.program_id(1)
-    req_idx = tl.load(req_indices + batch_idx)
-    do_overwrite = tl.load(overwrite + batch_idx)
-
-    group_new_blocks_ptr = cu_num_new_blocks_ptr + group_id * cu_num_new_blocks_stride
-    start_idx = tl.load(group_new_blocks_ptr + batch_idx)
-    end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1)
-    num_new_blocks = end_idx - start_idx
-
-    group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
-    dst_start_idx = tl.load(group_num_blocks_ptr + req_idx) if not do_overwrite else 0
-    dst_end_idx = dst_start_idx + num_new_blocks
-    tl.store(group_num_blocks_ptr + req_idx, dst_end_idx)
-
-    # Destination
-    block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
-    block_table_stride = tl.load(block_table_strides + group_id)
-    row_ptr = block_table_ptr + req_idx * block_table_stride
-
-    group_new_block_ids_ptr = new_block_ids_ptr + group_id * new_block_ids_stride
-    for i in range(0, num_new_blocks, BLOCK_SIZE):
-        offset = i + tl.arange(0, BLOCK_SIZE)
-        block_ids = tl.load(
-            group_new_block_ids_ptr + start_idx + offset, mask=offset < num_new_blocks
-        )
-        tl.store(
-            row_ptr + dst_start_idx + offset, block_ids, mask=offset < num_new_blocks
-        )
-
-
 @triton.jit
 def _gather_block_tables_kernel(
     batch_idx_to_req_idx,  # [batch_size]
@@ -312,3 +238,12 @@ def _load_ptr(ptr_to_ptr, elem_dtype):
     ptr = tl.load(ptr_to_ptr)
     ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
     return tl.multiple_of(ptr, 16)
+
+
+class UvaBuffer:
+    """Pinned CPU tensor (`cpu`) with NumPy (`np`) and CUDA (`gpu`) views."""
+
+    def __init__(self, *size, dtype: torch.dtype):
+        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=True)
+        self.np = self.cpu.numpy()
+        self.gpu = get_cuda_view_from_cpu_tensor(self.cpu)
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 9c170d46590d..4bc75df0c2f0 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -189,7 +189,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             max_num_batched_tokens=self.max_num_tokens,
             max_model_len=self.max_model_len,
             device=self.device,
-            pin_memory=self.pin_memory,
         )
 
         self.attn_backends, self.attn_metadata_builders = init_attn_backend(
@@ -378,16 +377,6 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
         for req_id in scheduler_output.finished_req_ids:
             self.req_states.remove_request(req_id)
 
-        # TODO(woosuk): Change SchedulerOutput.
-        req_indices: list[int] = []
-        cu_num_new_blocks = tuple(
-            [0] for _ in range(self.block_tables.num_kv_cache_groups)
-        )
-        new_block_ids: tuple[list[int], ...] = tuple(
-            [] for _ in range(self.block_tables.num_kv_cache_groups)
-        )
-        overwrite: list[bool] = []
-
         # Add new requests.
         for new_req_data in scheduler_output.scheduled_new_reqs:
             assert new_req_data.prompt_token_ids is not None
@@ -404,12 +393,9 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
             )
 
             req_index = self.req_states.req_id_to_index[req_id]
-            req_indices.append(req_index)
-            for i, block_ids in enumerate(new_req_data.block_ids):
-                x = cu_num_new_blocks[i][-1]
-                cu_num_new_blocks[i].append(x + len(block_ids))
-                new_block_ids[i].extend(block_ids)
-            overwrite.append(True)
+            self.block_tables.append_block_ids(
+                req_index, new_req_data.block_ids, overwrite=True
+            )
 
         if scheduler_output.scheduled_new_reqs:
             self.req_states.prefill_len.copy_to_gpu()
@@ -417,23 +403,11 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
         cached_reqs = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(cached_reqs.req_ids):
             req_index = self.req_states.req_id_to_index[req_id]
-
             req_new_block_ids = cached_reqs.new_block_ids[i]
             if req_new_block_ids is not None:
-                req_indices.append(req_index)
-                for group_id, block_ids in enumerate(req_new_block_ids):
-                    x = cu_num_new_blocks[group_id][-1]
-                    cu_num_new_blocks[group_id].append(x + len(block_ids))
-                    new_block_ids[group_id].extend(block_ids)
-                overwrite.append(False)
-
-        if req_indices:
-            self.block_tables.append_block_ids(
-                req_indices=req_indices,
-                cu_num_new_blocks=cu_num_new_blocks,
-                new_block_ids=new_block_ids,
-                overwrite=overwrite,
-            )
+                self.block_tables.append_block_ids(
+                    req_index, req_new_block_ids, overwrite=False
+                )
 
     def prepare_inputs(
         self,