diff --git a/tests/v1/kv_offload/test_cpu_gpu.py b/tests/v1/kv_offload/test_cpu_gpu.py
index de482aec4a42..db851edbccbb 100644
--- a/tests/v1/kv_offload/test_cpu_gpu.py
+++ b/tests/v1/kv_offload/test_cpu_gpu.py
@@ -27,6 +27,7 @@
 DEVICE_TYPE = current_platform.device_type
 DEVICES = [f"{DEVICE_TYPE}:0"]
 NUM_MAPPINGS = [3]
+NUM_MAPPINGS_PER_GROUP = [2]
 
 
 @pytest.mark.parametrize("gpu_to_cpu", [True, False])
@@ -58,9 +59,7 @@ def test_transfer(
     # build CanonicalKVCacheTensor list: one per tensor
     kv_cache_tensors: list[CanonicalKVCacheTensor] = []
     for i in range(num_tensors):
-        gpu_tensor = torch.randint(
-            -128,
-            127,
+        gpu_tensor = torch.zeros(
             (num_gpu_blocks, gpu_page_size_bytes),
             dtype=torch.int8,
             device=device,
@@ -119,26 +118,36 @@ def test_transfer(
             for j in range(block_size_factor)
         ]
 
-    # maybe skip some GPU blocks to test reading from the middle of a CPU block
-    if not gpu_to_cpu:
-        blocks_to_skip = block_size_factor - 1
+    # maybe skip GPU blocks to test reading/writing from the middle of a CPU block
+    blocks_to_skip = block_size_factor - 1
+    if blocks_to_skip > 0:
         gpu_blocks = gpu_blocks[blocks_to_skip:]
         cpu_blocks_expanded = cpu_blocks_expanded[blocks_to_skip:]
 
     # set transfer direction
     if gpu_to_cpu:
         handler = handlers.gpu_to_cpu_handler
-        src_spec = GPULoadStoreSpec(gpu_blocks, group_sizes=(len(gpu_blocks),))
+        src_spec = GPULoadStoreSpec(
+            gpu_blocks, group_sizes=(len(gpu_blocks),), block_indices=(blocks_to_skip,)
+        )
         dst_spec = CPULoadStoreSpec(cpu_blocks)
         dst_to_src = dict(zip(cpu_blocks_expanded, gpu_blocks))
-        num_dst_sub_blocks = num_cpu_blocks * block_size_factor
+        num_dst_sub_blocks = num_gpu_blocks
     else:
         handler = handlers.cpu_to_gpu_handler
         src_spec = CPULoadStoreSpec(cpu_blocks)
-        dst_spec = GPULoadStoreSpec(gpu_blocks, group_sizes=(len(gpu_blocks),))
+        dst_spec = GPULoadStoreSpec(
+            gpu_blocks, group_sizes=(len(gpu_blocks),), block_indices=(blocks_to_skip,)
+        )
         dst_to_src = dict(zip(gpu_blocks, cpu_blocks_expanded))
         num_dst_sub_blocks = num_gpu_blocks
 
+    # randomize src and dst tensors before transfer
+    for tensor in handler.src_tensors:
+        tensor.random_()
+    for tensor in handler.dst_tensors:
+        tensor.random_()
+
     # clone src and dst tensors before transfer
     orig_src_tensors = [x.clone() for x in handler.src_tensors]
     orig_dst_tensors = [x.clone() for x in handler.dst_tensors]
@@ -146,7 +155,7 @@ def test_transfer(
     # call transfer function
     start_time = time.time()
     assert handler.transfer_async(1, (src_spec, dst_spec))
-    assert set({x.job_id for x in handler._transfers}) == {1}
+    assert {x.job_id for x in handler._transfers} == {1}
 
     # wait for transfer to complete
     end_time = time.time() + 10
@@ -155,11 +164,14 @@ def test_transfer(
         if finished:
             assert finished[0].job_id == 1
             assert finished[0].success
-            assert finished[0].transfer_type == (
-                ("GPU", "CPU") if gpu_to_cpu else ("CPU", "GPU")
+            assert finished[0].transfer_type == (
+                ("GPU", "CPU")
+                if gpu_to_cpu
+                else ("CPU", "GPU")
             )
             assert finished[0].transfer_size == (
-                len(gpu_blocks) * handler.group_block_size_in_bytes[0]
+                len(gpu_blocks)
+                * sum(x.page_size_bytes for x in handler.kv_cache_groups_data_refs[0])
             )
             assert finished[0].transfer_time > 0
             assert finished[0].transfer_time < (time.time() - start_time)
@@ -196,3 +208,211 @@ def test_transfer(
     handlers.gpu_to_cpu_handler.shutdown()
     if mmap_region:
         mmap_region.cleanup()
+
+
+@pytest.mark.parametrize("gpu_to_cpu", [True, False])
+@pytest.mark.parametrize("num_mappings_per_group", NUM_MAPPINGS_PER_GROUP)
+@pytest.mark.parametrize("gpu_page_size_bytes", GPU_PAGE_SIZES)
+@pytest.mark.parametrize("block_size_factor", BLOCK_SIZE_FACTORS)
+@pytest.mark.parametrize("num_gpu_blocks", NUM_GPU_BLOCKS)
+@pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_transfer_multi_group(
+    default_vllm_config,
+    gpu_to_cpu: bool,
+    num_mappings_per_group: int,
+    gpu_page_size_bytes: int,
+    block_size_factor: int,
+    num_gpu_blocks: int,
+    num_cpu_blocks: int,
+    seed: int,
+    device: str,
+) -> None:
+    """Test transfers with three KV cache groups:
+    - Group 0: aligned transfer with num_mappings_per_group blocks
+    - Group 1: zero blocks (empty group)
+    - Group 2: unaligned transfer (block_indices[2]=block_size_factor-1,
+      causing the implementation to skip leading sub-blocks of the first CPU
+      block) with num_mappings_per_group blocks
+    """
+    set_random_seed(seed)
+
+    # 3 groups, each with 2 tensors
+    num_groups = 3
+    tensors_per_group = 2
+    num_tensors = num_groups * tensors_per_group
+    kv_cache_tensors: list[CanonicalKVCacheTensor] = []
+    for _ in range(num_tensors):
+        gpu_tensor = torch.zeros(
+            (num_gpu_blocks, gpu_page_size_bytes),
+            dtype=torch.int8,
+            device=device,
+        )
+        kv_cache_tensors.append(
+            CanonicalKVCacheTensor(
+                tensor=gpu_tensor,
+                page_size_bytes=gpu_page_size_bytes,
+            )
+        )
+
+    kv_cache_groups_data_refs: list[list[CanonicalKVCacheRef]] = [
+        [
+            CanonicalKVCacheRef(
+                tensor_idx=g * tensors_per_group + i,
+                page_size_bytes=gpu_page_size_bytes,
+            )
+            for i in range(tensors_per_group)
+        ]
+        for g in range(num_groups)
+    ]
+
+    canonical_kv_caches = CanonicalKVCaches(
+        tensors=kv_cache_tensors, group_data_refs=kv_cache_groups_data_refs
+    )
+
+    handlers = CpuGpuOffloadingHandlers(
+        kv_caches=canonical_kv_caches,
+        block_size_factor=block_size_factor,
+        num_cpu_blocks=num_cpu_blocks,
+    )
+
+    # group 0: aligned, group 1: empty, group 2: unaligned
+    group_sizes_in_cpu_blocks = [num_mappings_per_group, 0, num_mappings_per_group]
+
+    total_cpu_blocks = sum(group_sizes_in_cpu_blocks)
+    total_gpu_blocks_needed = total_cpu_blocks * block_size_factor
+    gpu_blocks_all = random.sample(range(num_gpu_blocks), total_gpu_blocks_needed)
+    cpu_blocks_all = random.sample(range(num_cpu_blocks), total_cpu_blocks)
+
+    # split gpu/cpu blocks per group
+    gpu_blocks_per_group: list[list[int]] = []
+    cpu_blocks_per_group: list[list[int]] = []
+    gpu_offset = 0
+    cpu_offset = 0
+    for size in group_sizes_in_cpu_blocks:
+        gpu_count = size * block_size_factor
+        gpu_blocks_per_group.append(gpu_blocks_all[gpu_offset : gpu_offset + gpu_count])
+        cpu_blocks_per_group.append(cpu_blocks_all[cpu_offset : cpu_offset + size])
+        gpu_offset += gpu_count
+        cpu_offset += size
+
+    # expand cpu blocks to gpu-page granularity
+    cpu_blocks_expanded_per_group = [
+        [
+            cpu_block * block_size_factor + j
+            for cpu_block in cpu_blocks
+            for j in range(block_size_factor)
+        ]
+        for cpu_blocks in cpu_blocks_per_group
+    ]
+
+    # skip sub-blocks from group 2 to test unaligned transfers.
+    sub_blocks_to_skip = block_size_factor - 1  # e.g. 2 when block_size_factor=3
+    if sub_blocks_to_skip > 0:
+        gpu_blocks_per_group[2] = gpu_blocks_per_group[2][
+            sub_blocks_to_skip:-sub_blocks_to_skip
+        ]
+        cpu_blocks_expanded_per_group[2] = cpu_blocks_expanded_per_group[2][
+            sub_blocks_to_skip:-sub_blocks_to_skip
+        ]
+
+    # build flat gpu_blocks list and group_sizes in GPU blocks
+    gpu_blocks: list[int] = []
+    group_sizes: list[int] = []
+    for gpu_blks in gpu_blocks_per_group:
+        gpu_blocks.extend(gpu_blks)
+        group_sizes.append(len(gpu_blks))
+
+    # build flat cpu_blocks list
+    cpu_blocks = []
+    for cpu_blks in cpu_blocks_per_group:
+        cpu_blocks.extend(cpu_blks)
+
+    # block_indices: only relevant for unaligned transfers
+    block_indices: list[int] = [0, 0, sub_blocks_to_skip]
+
+    if gpu_to_cpu:
+        handler = handlers.gpu_to_cpu_handler
+        src_spec = GPULoadStoreSpec(
+            gpu_blocks, group_sizes=group_sizes, block_indices=block_indices
+        )
+        dst_spec = CPULoadStoreSpec(cpu_blocks)
+        # per-group mapping: cpu sub-block -> gpu sub-block
+        dst_to_src_per_group = [
+            dict(zip(expanded, gpu_blks))
+            for expanded, gpu_blks in zip(
+                cpu_blocks_expanded_per_group, gpu_blocks_per_group
+            )
+        ]
+        num_dst_sub_blocks = num_cpu_blocks * block_size_factor
+    else:
+        handler = handlers.cpu_to_gpu_handler
+        src_spec = CPULoadStoreSpec(cpu_blocks)
+        dst_spec = GPULoadStoreSpec(
+            gpu_blocks, group_sizes=group_sizes, block_indices=block_indices
+        )
+        # per-group mapping: gpu sub-block -> cpu sub-block
+        dst_to_src_per_group = [
+            dict(zip(gpu_blks, expanded))
+            for gpu_blks, expanded in zip(
+                gpu_blocks_per_group, cpu_blocks_expanded_per_group
+            )
+        ]
+        num_dst_sub_blocks = num_gpu_blocks
+
+    # randomize src and dst tensors before transfer
+    for tensor in handler.src_tensors:
+        tensor.random_()
+    for tensor in handler.dst_tensors:
+        tensor.random_()
+
+    orig_src_tensors = [x.clone() for x in handler.src_tensors]
+    orig_dst_tensors = [x.clone() for x in handler.dst_tensors]
+
+    assert handler.transfer_async(1, (src_spec, dst_spec))
+    assert {x.job_id for x in handler._transfers} == {1}
+
+    end_time = time.time() + 10
+    while time.time() < end_time:
+        finished = handler.get_finished()
+        if finished:
+            assert finished[0].job_id == 1
+            assert finished[0].success
+            expected_bytes = sum(
+                group_size * sum([x.page_size_bytes for x in data_refs])
+                for group_size, data_refs in zip(
+                    group_sizes, handler.kv_cache_groups_data_refs
+                )
+            )
+            assert finished[0].transfer_size == expected_bytes
+            break
+        time.sleep(0.1)
+
+    # verify src tensors did not change
+    for orig_tensor, tensor in zip(orig_src_tensors, handler.src_tensors):
+        assert torch.equal(orig_tensor, tensor)
+
+    # verify dst tensors at gpu-page granularity
+    for group_idx, dst_to_src in enumerate(dst_to_src_per_group):
+        group_tensor_offset = group_idx * tensors_per_group
+        for tensor_idx in range(tensors_per_group):
+            src_tensor = handler.src_tensors[group_tensor_offset + tensor_idx]
+            dst_tensor = handler.dst_tensors[group_tensor_offset + tensor_idx]
+            orig_dst_tensor = orig_dst_tensors[group_tensor_offset + tensor_idx]
+            src_view = src_tensor.view(-1, gpu_page_size_bytes)
+            dst_view = dst_tensor.view(-1, gpu_page_size_bytes)
+            orig_dst_view = orig_dst_tensor.view(-1, gpu_page_size_bytes)
+            for dst_sub_block in range(num_dst_sub_blocks):
+                src_sub_block = dst_to_src.get(dst_sub_block)
+                if src_sub_block is not None:
+                    expected = src_view[src_sub_block]
+                else:
+                    expected = orig_dst_view[dst_sub_block]
+                torch.testing.assert_close(
+                    dst_view[dst_sub_block].cpu(), expected.cpu()
+                )
+
+    handlers.cpu_to_gpu_handler.shutdown()
+    handlers.gpu_to_cpu_handler.shutdown()
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
index cd5a4f113dc2..bff512815a65 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
@@ -381,7 +381,9 @@ def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
             for i in range(self.config.block_size_factor):
                 src_block_ids.append(block_ids[gpu_block_idx + i])
             src_spec = GPULoadStoreSpec(
-                src_block_ids, group_sizes=(len(src_block_ids),)
+                src_block_ids,
+                group_sizes=(len(src_block_ids),),
+                block_indices=(0,),
             )
             reqs_to_store[req_id] = (src_spec, dst_spec)
 
diff --git a/vllm/v1/kv_offload/mediums.py b/vllm/v1/kv_offload/mediums.py
index 85ef2a95a6bd..02e36a80a8e7 100644
--- a/vllm/v1/kv_offload/mediums.py
+++ b/vllm/v1/kv_offload/mediums.py
@@ -34,26 +34,24 @@ class GPULoadStoreSpec(BlockIDsLoadStoreSpec):
     will correspond to logically contiguous blocks, e.g. blocks 5-10 of a some
     request.
     block_indices[i] will represent the block index of the first block in group #i.
     Thus, len(block_indices) == len(group_sizes) = number of KV cache groups.
-    This information is required in order to support loading from offloaded blocks
+    This information is required to support offloading to and loading from blocks
     which are larger than GPU blocks. In such cases, the first GPU block per each
     group may be unaligned to the offloaded block size, and so knowing block_indices[i]
     allows the worker to correctly skip part of the first matching offloaded block.
-    Offloading from GPU is always aligned to offloaded block size, and so
-    block_indices will only be set by the offloading connector when loading into GPU.
     """
 
     def __init__(
         self,
         block_ids: list[int],
         group_sizes: Sequence[int],
-        block_indices: Sequence[int] | None = None,
+        block_indices: Sequence[int],
     ):
         super().__init__(block_ids)
         assert sum(group_sizes) == len(block_ids)
-        assert block_indices is None or len(block_indices) == len(group_sizes)
+        assert len(block_indices) == len(group_sizes)
         self.group_sizes: Sequence[int] = group_sizes
-        self.block_indices: Sequence[int] | None = block_indices
+        self.block_indices: Sequence[int] = block_indices
 
     @staticmethod
     def medium() -> str:
diff --git a/vllm/v1/kv_offload/worker/cpu_gpu.py b/vllm/v1/kv_offload/worker/cpu_gpu.py
index dd12a533ede9..aab57ef2be4d 100644
--- a/vllm/v1/kv_offload/worker/cpu_gpu.py
+++ b/vllm/v1/kv_offload/worker/cpu_gpu.py
@@ -9,9 +9,10 @@
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
+from vllm.utils.math_utils import cdiv
 from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.v1.kv_offload.cpu.shared_offload_region import SharedOffloadRegion
-from vllm.v1.kv_offload.mediums import BlockIDsLoadStoreSpec
+from vllm.v1.kv_offload.mediums import BlockIDsLoadStoreSpec, GPULoadStoreSpec
 from vllm.v1.kv_offload.spec import CanonicalKVCacheRef, CanonicalKVCaches
 from vllm.v1.kv_offload.worker.worker import (
     OffloadingHandler,
@@ -135,9 +136,6 @@ def __init__(
         assert len(gpu_tensors) == len(cpu_tensors)
         assert len(gpu_tensors) > 0
 
-        # assert a single KV group until transfer_async supports multiple groups
-        assert len(kv_cache_groups_data_refs) == 1
-
         # assert input tensors are as expected
         for gpu_tensor, cpu_tensor in zip(gpu_tensors, cpu_tensors):
             assert gpu_tensor.dtype == torch.int8
@@ -157,29 +155,13 @@ def __init__(
             cpu_tensors if gpu_to_cpu else gpu_tensors
         )
         self.gpu_to_cpu: bool = gpu_to_cpu
+        self.kv_cache_groups_data_refs = kv_cache_groups_data_refs
 
         # GPU blocks may be smaller
        # cpu_page_size = gpu_page_size * block_size_factor.
         self.src_block_size_factor = 1 if self.gpu_to_cpu else block_size_factor
         self.dst_block_size_factor = block_size_factor if self.gpu_to_cpu else 1
-        # per-tensor block size in byte
-        self.tensor_block_size_in_bytes = [
-            gpu_tensor.shape[1] for gpu_tensor in gpu_tensors
-        ]
-
-        # per-group block size in bytes
-        self.group_block_size_in_bytes = []
-        for kv_cache_group_data_refs in kv_cache_groups_data_refs:
-            group_block_size_in_bytes = 0
-            for kv_cache_data_ref in kv_cache_group_data_refs:
-                # TODO(orozery): use kv_cache_data_ref.page_size_bytes
-                # once swap_blocks support it
-                group_block_size_in_bytes += self.tensor_block_size_in_bytes[
-                    kv_cache_data_ref.tensor_idx
-                ]
-            self.group_block_size_in_bytes.append(group_block_size_in_bytes)
-
         self.transfer_type = ("GPU", "CPU") if self.gpu_to_cpu else ("CPU", "GPU")
 
         # job_id -> event
         self._transfer_events: dict[int, torch.Event] = {}
@@ -190,11 +172,6 @@ def __init__(
         # list of CUDA events available for re-use
         self._event_pool: list[torch.Event] = []
 
-        # Pre-compute block sizes for batch copies.
-        self._block_size_in_bytes_arr = np.array(
-            self.tensor_block_size_in_bytes, dtype=np.int64
-        )
-
     def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
         src_spec, dst_spec = transfer_spec
         assert isinstance(src_spec, BlockIDsLoadStoreSpec)
@@ -205,37 +182,108 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
         assert src_blocks.ndim == 1
         assert dst_blocks.ndim == 1
 
-        src_sub_block_count = src_blocks.size * self.src_block_size_factor
-        dst_sub_block_count = dst_blocks.size * self.dst_block_size_factor
-        src_sub_blocks_to_skip = -dst_blocks.size % self.src_block_size_factor
-
-        assert dst_sub_block_count == src_sub_block_count - src_sub_blocks_to_skip
-
-        num_pairs = dst_sub_block_count
-        num_tensors = len(self.src_tensors)
-        total = num_pairs * num_tensors
-
-        all_src = np.empty(total, dtype=np.int64)
-        all_dst = np.empty(total, dtype=np.int64)
-        all_sizes = np.empty(total, dtype=np.int64)
-
-        for t_idx, bsz in enumerate(self._block_size_in_bytes_arr):
-            start = t_idx * num_pairs
-            end = start + num_pairs
-            compute_sub_block_ptrs(
-                block_ids=src_blocks,
-                block_size_factor=self.src_block_size_factor,
-                output=all_src[start:end],
-                tensor=self.src_tensors[t_idx],
-                skip_count=src_sub_blocks_to_skip,
+        num_src_blocks = len(src_blocks)
+        num_dst_blocks = len(dst_blocks)
+
+        # There are 2 types of transfers:
+        # 1. GPU -> CPU
+        # 2. CPU -> GPU
+        #
+        # transfers are aligned to CPU blocks, except maybe the first and last block.
+        # i.e. the first and last CPU blocks of a group can match against
+        # a smaller (byte-wise) set of GPU blocks.
+        # In such cases, we may need to skip some gpu-sized sub-blocks,
+        # and start reading/writing from the middle of the first CPU block.
+        # If we have multiple KV cache groups (when using HMA with hybrid models),
+        # we may have a partial first/last CPU block per each group.
+        # The group_sizes field of the GPU spec encodes the number of GPU blocks
+        # that belong to each KV cache group.
+        # The block_indices field of the GPU spec gives, for each group, the
+        # logical index of the group's first GPU block inside the request,
+        # counting in GPU blocks.
+        # This allows us to find the correct starting position
+        # in the matching first CPU block.
+
+        # extract group_sizes from the GPU spec
+        gpu_spec = src_spec if self.gpu_to_cpu else dst_spec
+        assert isinstance(gpu_spec, GPULoadStoreSpec)
+        group_sizes = gpu_spec.group_sizes
+        assert len(group_sizes) == len(self.kv_cache_groups_data_refs)
+
+        # extract block indices from the GPU spec
+        block_indices = gpu_spec.block_indices
+        assert len(block_indices) == len(self.kv_cache_groups_data_refs)
+
+        num_copy_ops = 0
+        for group_size, group_data_refs in zip(
+            group_sizes, self.kv_cache_groups_data_refs
+        ):
+            num_copy_ops += group_size * len(group_data_refs)
+
+        all_src = np.empty(num_copy_ops, dtype=np.int64)
+        all_dst = np.empty(num_copy_ops, dtype=np.int64)
+        all_sizes = np.empty(num_copy_ops, dtype=np.int64)
+
+        src_offset = 0
+        dst_offset = 0
+        op_idx = 0
+        # count total number of bytes copied
+        num_transfer_bytes = 0
+        for group_size, block_idx, group_data_refs in zip(
+            group_sizes, block_indices, self.kv_cache_groups_data_refs
+        ):
+            if group_size == 0:
+                continue
+
+            src_logical_blocks_to_skip = block_idx % self.src_block_size_factor
+            dst_logical_blocks_to_skip = block_idx % self.dst_block_size_factor
+            src_logical_blocks_count = group_size + src_logical_blocks_to_skip
+            dst_logical_blocks_count = group_size + dst_logical_blocks_to_skip
+
+            dst_blocks_count = cdiv(
+                dst_logical_blocks_count, self.dst_block_size_factor
             )
-            compute_sub_block_ptrs(
-                block_ids=dst_blocks,
-                block_size_factor=self.dst_block_size_factor,
-                output=all_dst[start:end],
-                tensor=self.dst_tensors[t_idx],
+            dst_end_offset = dst_offset + dst_blocks_count
+            assert dst_end_offset <= num_dst_blocks
+
+            src_blocks_count = cdiv(
+                src_logical_blocks_count, self.src_block_size_factor
             )
-            all_sizes[start:end] = bsz
+            src_end_offset = src_offset + src_blocks_count
+            assert src_end_offset <= num_src_blocks
+
+            group_src = src_blocks[src_offset:src_end_offset]
+            group_dst = dst_blocks[dst_offset:dst_end_offset]
+
+            for data_ref in group_data_refs:
+                t_idx = data_ref.tensor_idx
+                end_idx = op_idx + group_size
+
+                compute_sub_block_ptrs(
+                    group_src,
+                    self.src_block_size_factor,
+                    all_src[op_idx:end_idx],
+                    self.src_tensors[t_idx],
+                    skip_count=src_logical_blocks_to_skip,
+                )
+                compute_sub_block_ptrs(
+                    group_dst,
+                    self.dst_block_size_factor,
+                    all_dst[op_idx:end_idx],
+                    self.dst_tensors[t_idx],
+                    skip_count=dst_logical_blocks_to_skip,
+                )
+
+                all_sizes[op_idx:end_idx] = data_ref.page_size_bytes
+                num_transfer_bytes += group_size * data_ref.page_size_bytes
+                op_idx = end_idx
+
+            src_offset = src_end_offset
+            dst_offset = dst_end_offset
+
+        assert src_offset == num_src_blocks
+        assert dst_offset == num_dst_blocks
+        assert op_idx == num_copy_ops
 
         batch_src = torch.from_numpy(all_src)
         batch_dst = torch.from_numpy(all_dst)
@@ -263,7 +311,7 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
             stream.wait_event(last_event)
         with torch.cuda.stream(stream):
             start_event.record(stream)
-            if total > 0:
+            if num_copy_ops > 0:
                 ops.swap_blocks_batch(batch_src, batch_dst, batch_sizes)
             end_event.record(stream)
 
@@ -274,7 +322,7 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
                 stream=stream,
                 start_event=start_event,
                 end_event=end_event,
-                num_bytes=dst_sub_block_count * self.group_block_size_in_bytes[0],
+                num_bytes=num_transfer_bytes,
             )
         )
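
Illustrative usage sketch (not part of the patch): how a caller constructs the specs now that block_indices is a required argument of GPULoadStoreSpec. The block IDs and counts below are made-up values, and the CPULoadStoreSpec import location is assumed to sit alongside GPULoadStoreSpec in vllm/v1/kv_offload/mediums.py.

    from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec

    # Two KV cache groups; each offloaded (CPU) block spans 4 GPU blocks.
    block_size_factor = 4
    group_0_gpu_blocks = [8, 9, 10, 11]  # aligned: starts at a CPU-block boundary
    group_1_gpu_blocks = [20, 21, 22]    # unaligned: enters its first CPU block 1 GPU block in

    gpu_spec = GPULoadStoreSpec(
        group_0_gpu_blocks + group_1_gpu_blocks,
        group_sizes=(len(group_0_gpu_blocks), len(group_1_gpu_blocks)),
        # logical index of each group's first GPU block within the request;
        # block_indices[i] % block_size_factor is the offset inside the first CPU block
        block_indices=(0, 1),
    )
    # one CPU block per group in this toy example (cdiv(4 + 0, 4) == cdiv(3 + 1, 4) == 1)
    cpu_spec = CPULoadStoreSpec([3, 7])

With these values the worker skips one gpu-sized sub-block of CPU block 7 for group 1, matching the unaligned case exercised by test_transfer_multi_group above.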