diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 9ab5af0f6fb0..623ff502ac07 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -799,25 +799,53 @@ def is_kv_cache_spec_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
     return True
 
 
-def get_max_concurrency_for_kv_cache_config(
+def _blocks_per_request(
     vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
-) -> float:
-    """
-    Get the maximum concurrency for the given KV cache configuration.
+) -> int:
+    """Return number of blocks needed per request at max_model_len.
+
+    Note: the num_layer_per_group factor appears in both numerator and
+    denominator and cancels out, so the result is correct regardless of
+    whether page_size_bytes already includes all layers (as in
+    UniformTypeKVCacheSpecs) or is per-layer.
     """
     num_layer_per_group = max(
        len(group.layer_names) for group in kv_cache_config.kv_cache_groups
    )
+    page_size = kv_cache_config.kv_cache_groups[0].kv_cache_spec.page_size_bytes
     max_memory_usage_per_request = num_layer_per_group * max_memory_usage_bytes(
-        vllm_config, (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups)
+        vllm_config,
+        (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups),
     )
-    memory_per_block = (
-        kv_cache_config.kv_cache_groups[0].kv_cache_spec.page_size_bytes
-        * num_layer_per_group
+    memory_per_block = page_size * num_layer_per_group
+    return cdiv(max_memory_usage_per_request, memory_per_block)
+
+
+def _kv_cache_bytes_per_block(kv_cache_config: KVCacheConfig) -> int:
+    """Return the actual memory footprint of one block across all layers.
+
+    For UniformTypeKVCacheSpecs, page_size_bytes already sums across all
+    layers in the group. For other spec types, page_size_bytes is per-layer
+    and must be multiplied by the layer count.
+    """
+    spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
+    if isinstance(spec, UniformTypeKVCacheSpecs):
+        return spec.page_size_bytes
+    num_layers = max(
+        len(g.layer_names) for g in kv_cache_config.kv_cache_groups
+    )
+    return spec.page_size_bytes * num_layers
+
+
+def get_max_concurrency_for_kv_cache_config(
+    vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
+) -> float:
+    """
+    Get the maximum concurrency for the given KV cache configuration.
+    """
+    return kv_cache_config.num_blocks / _blocks_per_request(
+        vllm_config, kv_cache_config
     )
-    num_block_per_request = cdiv(max_memory_usage_per_request, memory_per_block)
-    max_concurrency = kv_cache_config.num_blocks / num_block_per_request
-    return max_concurrency
 
 
 def may_override_num_blocks(vllm_config: VllmConfig, num_blocks: int) -> int:
@@ -1328,6 +1356,36 @@ def _report_kv_cache_config(
         scope="local",
     )
 
+    # Log KV cache memory in GiB and workload capacity.
+    blocks_per_req = _blocks_per_request(vllm_config, kv_cache_config)
+    bytes_per_block = _kv_cache_bytes_per_block(kv_cache_config)
+    allocated_kv_bytes = kv_cache_config.num_blocks * bytes_per_block
+    logger.info_once(
+        "GPU KV cache memory: %s GiB (%d blocks)",
+        format_gib(allocated_kv_bytes),
+        kv_cache_config.num_blocks,
+        scope="local",
+    )
+
+    max_num_seqs = vllm_config.scheduler_config.max_num_seqs
+    needed_kv_bytes = blocks_per_req * max_num_seqs * bytes_per_block
+    logger.info_once(
+        "KV cache for %d seqs x %s tokens: %s GiB (allocated: %s GiB)",
+        max_num_seqs,
+        max_model_len_str,
+        format_gib(needed_kv_bytes),
+        format_gib(allocated_kv_bytes),
+        scope="local",
+    )
+    if max_concurrency < max_num_seqs:
+        logger.warning(
+            "KV cache can hold %d full-length sequences but "
+            "max_num_seqs is %d. "
+            "Sequences will queue when KV cache is full.",
+            int(max_concurrency),
+            max_num_seqs,
+        )
+
 
 def _max_memory_usage_bytes_from_groups(
     vllm_config: VllmConfig,