diff --git a/tests/compile/h100/test_startup.py b/tests/compile/h100/test_startup.py
index ff4496c2ba6d..78554a3e93da 100644
--- a/tests/compile/h100/test_startup.py
+++ b/tests/compile/h100/test_startup.py
@@ -34,7 +34,10 @@ def _run_vllm(vllm_runner):
             mode=CompilationMode.VLLM_COMPILE,
             cudagraph_mode=CUDAGraphMode.NONE,
         ),
-        num_gpu_blocks_override=8,
+        # Phi-tiny-MoE uses SWA, whose admission cap is `cdiv(L, block_size) + 1`
+        # at default block_size=16 — i.e. 17 blocks for max_model_len=256. Use
+        # 32 for headroom.
+        num_gpu_blocks_override=32,
     ):
         pass
 
@@ -190,7 +193,7 @@ def _run_model(vllm_runner, spec: ModelStartupSpec):
             cudagraph_mode=CUDAGraphMode.NONE,
             pass_config=PassConfig(fuse_allreduce_rms=False),
         ),
-        num_gpu_blocks_override=8,
+        num_gpu_blocks_override=16,
     ):
         pass
 
diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index cfd03c5f687e..985b97c69ca4 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -2074,6 +2074,54 @@ def test_auto_fit_max_model_len_not_triggered():
     assert vllm_config.model_config.max_model_len == 16
 
 
+def test_auto_fit_max_model_len_respects_num_gpu_blocks_override():
+    """Auto-fit must size max_model_len against the override-clamped pool, not
+    the raw `available_memory`. Without this, auto-fit could pick a
+    max_model_len that no longer fits once `num_gpu_blocks_override` is applied.
+    """
+    model_config = ModelConfig(max_model_len=16384)
+    model_config.original_max_model_len = -1  # request auto-fit
+    vllm_config = VllmConfig(model_config=model_config)
+    # Cap the cache to 32 blocks regardless of available memory.
+    vllm_config.cache_config.num_gpu_blocks_override = 32
+
+    mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2
+    kv_cache_specs = {
+        "layer_1": new_kv_cache_spec(),  # block_size=16
+        "layer_2": new_kv_cache_spec(),
+    }
+    # Plenty of raw memory (1024 blocks per layer would fit max_model_len=16384).
+    large_available_memory = mem_per_block_per_layer * 2 * 1024
+
+    get_kv_cache_configs(vllm_config, [kv_cache_specs], [large_available_memory])
+
+    # 32 blocks * block_size 16 = 512 token slots, so max_model_len must
+    # auto-fit at or below that.
+    assert 0 < vllm_config.model_config.max_model_len <= 32 * 16
+
+
+def test_check_enough_kv_cache_memory_respects_num_gpu_blocks_override():
+    """Admission check must use the override-clamped pool size, not raw
+    `available_memory`. Without this, startup could accept a max_model_len
+    that does not actually fit in `num_gpu_blocks_override` blocks.
+    """
+    model_config = ModelConfig(max_model_len=16384)
+    vllm_config = VllmConfig(model_config=model_config)
+    # 32 blocks is far too small for max_model_len=16384 (would need 1024).
+    vllm_config.cache_config.num_gpu_blocks_override = 32
+
+    mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2
+    kv_cache_specs = {
+        "layer_1": new_kv_cache_spec(),
+        "layer_2": new_kv_cache_spec(),
+    }
+    # Plenty of raw memory: a bytes-only check against this would pass.
+    large_available_memory = mem_per_block_per_layer * 2 * 1024
+
+    with pytest.raises(ValueError, match="max seq len"):
+        get_kv_cache_configs(vllm_config, [kv_cache_specs], [large_available_memory])
+
+
 def test_unify_hybrid_kv_cache_specs():
     # 1. has_full_attention and has_sliding_window
     before_spec_1 = new_kv_cache_spec()
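Reviewer note: the block arithmetic these two tests rely on can be sanity-checked standalone. A minimal sketch in plain Python; `cdiv` and the constants are re-derived from the comments in this patch, and nothing below is a vLLM API:

```python
# Standalone sanity check of the block arithmetic the new tests rely on.
# block_size=16 is the default used by new_kv_cache_spec() in these tests.

def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(a // -b)

BLOCK_SIZE = 16

# Admission: full attention needs one block per block_size tokens, so
# max_model_len=16384 needs 1024 blocks and a 32-block override must fail.
assert cdiv(16384, BLOCK_SIZE) == 1024

# Auto-fit: a 32-block pool holds 32 * 16 = 512 token slots, the ceiling
# asserted by test_auto_fit_max_model_len_respects_num_gpu_blocks_override.
assert 32 * BLOCK_SIZE == 512

# SWA cap from the test_startup.py comment: cdiv(L, block_size) + 1,
# i.e. 17 blocks at max_model_len=256.
assert cdiv(256, BLOCK_SIZE) + 1 == 17
```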
diff --git a/tests/v1/e2e/general/test_async_scheduling.py b/tests/v1/e2e/general/test_async_scheduling.py
index 8e1eddb0f64e..28a1bedbe0b2 100644
--- a/tests/v1/e2e/general/test_async_scheduling.py
+++ b/tests/v1/e2e/general/test_async_scheduling.py
@@ -324,10 +324,13 @@ def run_test(
 ):
     spec_decoding = spec_config is not None
     cache_arg: dict[str, Any] = (
-        # Force preemptions
-        dict(num_gpu_blocks_override=32)
+        # Force preemptions: with 32 blocks the cache holds at most a single
+        # max-length request, so the ~34 concurrent prompts contend and trigger
+        # preemption. (Prompts here are << max_model_len, so dropping
+        # max_model_len from 4096 to 512 doesn't change generation behavior.)
+        dict(num_gpu_blocks_override=32, max_model_len=512)
         if test_preemption
-        else dict(gpu_memory_utilization=0.9)
+        else dict(gpu_memory_utilization=0.9, max_model_len=4096)
     )
     spec_mml = (spec_config or {}).get("max_model_len")
     spec_method = (spec_config or {}).get("method", "none")
@@ -343,7 +346,6 @@ def run_test(
 
     with VllmRunner(
         model,
-        max_model_len=4096,
         enable_chunked_prefill=test_prefill_chunking,
         # Force prefill chunking
         max_num_batched_tokens=48 if test_prefill_chunking else None,
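Reviewer note: a quick sketch of the preemption arithmetic in the new comment, assuming vLLM's default block_size of 16 (the ~34 prompt count is taken from the comment itself; this is illustrative, not test code):

```python
# Why num_gpu_blocks_override=32 with max_model_len=512 still forces
# preemption (sketch only; assumes the default block_size of 16).

block_size = 16
num_blocks = 32
max_model_len = 512

# The whole pool holds exactly one max-length request's worth of KV cache...
assert num_blocks * block_size == max_model_len

# ...so with ~34 requests in flight, most must wait for blocks, and running
# requests get preempted as others grow, which is the behavior under test.
concurrent_prompts = 34  # approximate count cited in the comment
blocks_per_prompt = num_blocks / concurrent_prompts
assert blocks_per_prompt < 1  # not even one block per request on average
```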
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 3e0e7fcb8c5b..b57e10b67faa 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -890,31 +890,48 @@ def get_max_concurrency_for_kv_cache_config(
     return max_concurrency
 
 
-def may_override_num_blocks(
-    vllm_config: VllmConfig, num_blocks: int, suppress_log: bool = False
-) -> int:
+def may_override_num_blocks(vllm_config: VllmConfig, num_blocks: int) -> int:
     """
     Override the number of kv cache blocks if `num_gpu_blocks_override` is set.
+    The override is logged once, at the call site in `get_kv_cache_configs`.
     """
     if vllm_config.cache_config.num_gpu_blocks_override is not None:
-        num_gpu_blocks_override = vllm_config.cache_config.num_gpu_blocks_override
-        if not suppress_log:
-            logger.info(
-                "Overriding num_gpu_blocks=%d with num_gpu_blocks_override=%d",
-                num_blocks,
-                num_gpu_blocks_override,
-            )
-        num_blocks = num_gpu_blocks_override
-
+        num_blocks = vllm_config.cache_config.num_gpu_blocks_override
     return num_blocks
 
 
+def _pool_bytes_per_block(kv_cache_groups: list[KVCacheGroupSpec]) -> int:
+    """
+    Bytes consumed by one block in the worker's shared KV cache pool, mirroring
+    the divisor used by `get_kv_cache_config_from_groups` to convert
+    `available_memory` into `num_blocks`. Used to compute the effective KV cache
+    capacity once `num_gpu_blocks_override` is applied.
+    """
+    if len(kv_cache_groups) == 1 and isinstance(
+        kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs
+    ):
+        return kv_cache_groups[0].kv_cache_spec.page_size_bytes
+    if all(
+        isinstance(g.kv_cache_spec, UniformTypeKVCacheSpecs) for g in kv_cache_groups
+    ):
+        # DeepseekV4: shared layout sized by the largest per-page-size bucket.
+        full_mla_spec = cast(UniformTypeKVCacheSpecs, kv_cache_groups[0].kv_cache_spec)
+        layer_tuple_page_bytes = sum(full_mla_spec.get_page_sizes())
+        num_layer_tuples = max(
+            cast(UniformTypeKVCacheSpecs, g.kv_cache_spec).get_num_layer_tuples()
+            for g in kv_cache_groups
+        )
+        return layer_tuple_page_bytes * num_layer_tuples
+    group_size = max(len(g.layer_names) for g in kv_cache_groups)
+    page_size = get_uniform_page_size([g.kv_cache_spec for g in kv_cache_groups])
+    return page_size * group_size
+
+
 def get_num_blocks(
     vllm_config: VllmConfig,
     num_layers: int,
     available_memory: int,
     page_size: int,
-    suppress_log: bool = False,
 ) -> int:
     """
     Get the number of kv cache blocks.
@@ -924,15 +941,10 @@ def get_num_blocks(
         num_layers: The number of layers
        available_memory: Memory available for KV cache in bytes.
         page_size: The page size of the KV cache.
-        suppress_log: Whether to suppress override log messages. Used when creating a
-            temporary/dummy KV cache config, e.g. during CG memory profiling
     """
     num_blocks = int(available_memory // page_size // num_layers)
     num_blocks = max(num_blocks, 0)
-    num_blocks = may_override_num_blocks(
-        vllm_config, num_blocks, suppress_log=suppress_log
-    )
-    return num_blocks
+    return may_override_num_blocks(vllm_config, num_blocks)
 
 
 def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int:
@@ -1220,7 +1232,6 @@ def get_kv_cache_config_from_groups(
     vllm_config: VllmConfig,
     kv_cache_groups: list[KVCacheGroupSpec],
     available_memory: int,
-    suppress_log: bool = False,
 ) -> KVCacheConfig:
     """
     Generate the KV cache configuration from the KV cache groups and spec
@@ -1252,9 +1263,7 @@ def get_kv_cache_config_from_groups(
         num_blocks = (
             available_memory // kv_cache_groups[0].kv_cache_spec.page_size_bytes
         )
-        num_blocks = may_override_num_blocks(
-            vllm_config, num_blocks, suppress_log=suppress_log
-        )
+        num_blocks = may_override_num_blocks(vllm_config, num_blocks)
         per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs
         kv_cache_tensors = [
             KVCacheTensor(
@@ -1288,11 +1297,7 @@ def get_kv_cache_config_from_groups(
     )
     assert group_size > 0, "group_size must be greater than 0"
     num_blocks = get_num_blocks(
-        vllm_config,
-        group_size,
-        available_memory,
-        page_size,
-        suppress_log=suppress_log,
+        vllm_config, group_size, available_memory, page_size
     )
     kv_cache_tensors = []
     for i in range(group_size):
@@ -1688,36 +1693,24 @@ def _report_kv_cache_config(
         vllm_config: The global VllmConfig
         kv_cache_config: The resolved KV cache configuration
     """
-    min_block_size = min(
-        [group.kv_cache_spec.block_size for group in kv_cache_config.kv_cache_groups]
-    )
-
-    # Log the KV cache size and maximum concurrency.
-    num_tokens = (
-        kv_cache_config.num_blocks
-        // len(kv_cache_config.kv_cache_groups)
-        * min_block_size
-    )
-    dcp_size = vllm_config.parallel_config.decode_context_parallel_size
-    pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
-    if pcp_size * dcp_size > 1:
-        num_tokens *= pcp_size * dcp_size
-        logger.info(
-            "Multiplying the GPU KV cache size by the cp_world_size %d "
-            "(pcp_world_size %d * dcp_world_size %d).",
-            pcp_size * dcp_size,
-            pcp_size,
-            dcp_size,
-        )
-    num_tokens_str = f"{num_tokens:,}"
-    logger.info_once("GPU KV cache size: %s tokens", num_tokens_str)
-    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
+    max_model_len = vllm_config.model_config.max_model_len
     max_concurrency = get_max_concurrency_for_kv_cache_config(
         vllm_config, kv_cache_config
     )
+
+    # GPU KV cache size in tokens = max_concurrency * max_model_len: the total
+    # tokens of context the pool can hold at peak utilization. Sourcing this
+    # from the concurrency calculation handles hybrid layouts correctly: SWA /
+    # chunked-local groups have a per-request block count that's capped by
+    # their window, so a naive `num_blocks // num_groups * block_size` formula
+    # underestimates capacity for these models. DCP/PCP sharding is already
+    # accounted for in each spec's `max_memory_usage_bytes`.
+    num_tokens = int(max_concurrency * max_model_len)
+
+    logger.info_once("GPU KV cache size: %s tokens", f"{num_tokens:,}")
     logger.info_once(
         "Maximum concurrency for %s tokens per request: %.2fx",
-        max_model_len_str,
+        f"{max_model_len:,}",
         max_concurrency,
     )
 
@@ -1988,6 +1981,28 @@ def get_kv_cache_configs(
         for worker_spec in kv_cache_specs
     ]
 
+    # If `num_gpu_blocks_override` is set, the cache size that will actually
+    # be allocated is decoupled from the profiled `available_memory`:
+    # `may_override_num_blocks` in `get_kv_cache_config_from_groups` clamps
+    # `num_blocks` to the override. Reflect that in `available_memory` here so
+    # auto-fit, the admission check, and the per-worker config builder all
+    # plan against the same effective capacity.
+    override = vllm_config.cache_config.num_gpu_blocks_override
+    if override is not None:
+        adjusted_memory: list[int] = []
+        for groups, avail_mem in zip(projected_groups_per_worker, available_memory):
+            if not groups:
+                adjusted_memory.append(avail_mem)
+                continue
+            bytes_per_block = _pool_bytes_per_block(groups)
+            logger.info(
+                "Overriding num_gpu_blocks=%d with num_gpu_blocks_override=%d",
+                avail_mem // bytes_per_block,
+                override,
+            )
+            adjusted_memory.append(override * bytes_per_block)
+        available_memory = adjusted_memory
+
     if vllm_config.model_config.original_max_model_len == -1:
         _auto_fit_max_model_len(
             vllm_config, projected_groups_per_worker, available_memory
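Reviewer note: to see what the `_report_kv_cache_config` change fixes, here is a toy comparison of the old and new token-count formulas for a hypothetical hybrid model (made-up sizes, simplified per-request accounting; not vLLM's allocator):

```python
# Old vs. new "GPU KV cache size" formulas for a toy hybrid model (sketch;
# made-up sizes, one full-attention group plus one SWA group sharing a pool).

def cdiv(a: int, b: int) -> int:
    return -(a // -b)

block_size = 16
max_model_len = 4096
swa_window = 256      # hypothetical sliding window
num_blocks = 4096     # shared block pool
num_groups = 2

# Per-request demand: the full-attention group holds the whole context;
# the SWA group is capped by its window (+1 block admission slack).
blocks_per_request = cdiv(max_model_len, block_size) + (cdiv(swa_window, block_size) + 1)
max_concurrency = num_blocks / blocks_per_request   # ~15x

old_tokens = num_blocks // num_groups * block_size  # naive: 32,768
new_tokens = int(max_concurrency * max_model_len)   # ~61,000

# The naive formula charges the SWA group as if it held full context,
# underestimating the pool's real capacity.
assert new_tokens > old_tokens
```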
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index a0ba47f945a7..caf3bfdfc3a8 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -5874,7 +5874,7 @@ def _init_minimal_kv_cache_for_profiling(self) -> None:
         saved_override = self.cache_config.num_gpu_blocks_override
         self.cache_config.num_gpu_blocks_override = min_blocks
         minimal_config = get_kv_cache_config_from_groups(
-            self.vllm_config, kv_cache_groups, available_memory=0, suppress_log=True
+            self.vllm_config, kv_cache_groups, available_memory=0
         )
         self.cache_config.num_gpu_blocks_override = saved_override
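Reviewer note: a round-trip check of the invariant the new `get_kv_cache_configs` block establishes. Clamping `available_memory` to `override * bytes_per_block` means the later `num_blocks = available_memory // page_size // num_layers` division in `get_num_blocks` recovers exactly the override, so auto-fit, the admission check, and the per-worker builder all plan against one capacity. A sketch with hypothetical numbers, no vLLM APIs:

```python
# Round-trip sketch of the override clamp (hypothetical numbers; mirrors
# get_kv_cache_configs -> get_num_blocks, but uses no vLLM APIs).

page_size = 16 * 2 * 64 * 4 * 2  # per-layer block bytes, as in the new tests
num_layers = 2                   # one group containing both layers
override = 32

# _pool_bytes_per_block for the plain uniform case: page_size * group_size.
bytes_per_block = page_size * num_layers
adjusted_memory = override * bytes_per_block  # what get_kv_cache_configs advertises

# What get_num_blocks then computes from the adjusted memory:
num_blocks = adjusted_memory // page_size // num_layers
assert num_blocks == override  # every consumer sees the same 32-block pool
```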