diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 37d6993abe67..ba40e8e45378 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5550,16 +5550,14 @@ def _init_minimal_kv_cache_for_profiling(self) -> None: kv_cache_spec = self.get_kv_cache_spec() kv_cache_groups = get_kv_cache_groups(self.vllm_config, kv_cache_spec) min_blocks = self.compilation_config.max_cudagraph_capture_size or 1 - if kv_cache_groups: - page_size = kv_cache_groups[0].kv_cache_spec.page_size_bytes - group_size = max(len(g.layer_names) for g in kv_cache_groups) - available_memory = min_blocks * page_size * group_size - else: - available_memory = 1 # Attention-free model + # Temporarily change num_gpu_blocks_override to allocate a minimal KV cache + saved_override = self.cache_config.num_gpu_blocks_override + self.cache_config.num_gpu_blocks_override = min_blocks minimal_config = get_kv_cache_config_from_groups( - self.vllm_config, kv_cache_groups, available_memory=available_memory + self.vllm_config, kv_cache_groups, available_memory=0 ) + self.cache_config.num_gpu_blocks_override = saved_override self.initialize_kv_cache(minimal_config) self.cache_config.num_gpu_blocks = minimal_config.num_blocks