Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions tests/entrypoints/serve/instrumentator/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,17 @@ async def test_metrics_exist(
continue
assert metric in response.text

cache_config_samples = [
sample
for family in text_string_to_metric_families(response.text)
if family.name == "vllm:cache_config_info"
for sample in family.samples
]
assert cache_config_samples
for sample in cache_config_samples:
assert sample.labels.get("kv_cache_size_tokens") not in (None, "None", "")
assert sample.labels.get("kv_cache_max_concurrency") not in (None, "None", "")


@pytest.mark.asyncio
async def test_abort_metrics_reset(
Expand Down
6 changes: 6 additions & 0 deletions tests/v1/core/test_kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
estimate_max_model_len,
generate_block_hash_extra_keys,
generate_scheduler_kv_cache_config,
get_kv_cache_capacity,
get_kv_cache_configs,
get_max_concurrency_for_kv_cache_config,
get_request_block_hasher,
Expand Down Expand Up @@ -1459,6 +1460,11 @@ def test_get_max_concurrency_for_kv_cache_config():
vllm_config, kv_cache_config_hybrid_model
)
assert max_concurrency_hybrid_model == 3
num_tokens, max_concurrency = get_kv_cache_capacity(
vllm_config, kv_cache_config_hybrid_model
)
assert num_tokens == max_concurrency_hybrid_model * max_model_len
assert max_concurrency == max_concurrency_hybrid_model


def test_allocate_with_lookahead():
Expand Down
10 changes: 10 additions & 0 deletions vllm/config/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,14 @@ class CacheConfig:
num_cpu_blocks: int | None = field(default=None, init=False)
"""The number of blocks to allocate for CPU memory."""

# Set after KV cache initialization.
kv_cache_size_tokens: int | None = field(default=None, init=False)
"""Per-DP-engine KV cache capacity in tokens (group-aware). Uses
group-aware capacity since num_gpu_blocks * block_size can be wrong
for hybrid models where requests occupy multiple KV cache groups."""
kv_cache_max_concurrency: float | None = field(default=None, init=False)
"""Per-DP-engine maximum concurrency at max_model_len tokens."""

kv_sharing_fast_prefill: bool = False
"""This feature is work in progress and no prefill optimization takes place
with this flag enabled currently.
Expand Down Expand Up @@ -204,6 +212,8 @@ def compute_hash(self) -> str:
# Post-init/derived counters
"num_gpu_blocks",
"num_cpu_blocks",
"kv_cache_size_tokens",
"kv_cache_max_concurrency",
# WIP feature toggle not impacting compiled graph shape
"kv_sharing_fast_prefill",
}
Expand Down
43 changes: 19 additions & 24 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1717,36 +1717,17 @@ def generate_scheduler_kv_cache_config(
return cfg


def _report_kv_cache_config(
def get_kv_cache_capacity(
vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
) -> None:
) -> tuple[int, float]:
"""
Log resolved KV cache configuration.

Args:
vllm_config: The global VllmConfig
kv_cache_config: The resolved KV cache configuration
Get the group-aware KV cache token capacity and max concurrency.
"""
max_model_len = vllm_config.model_config.max_model_len
max_concurrency = get_max_concurrency_for_kv_cache_config(
vllm_config, kv_cache_config
)

# GPU KV cache size in tokens = max_concurrency * max_model_len: the total
# tokens of context the pool can hold at peak utilization. Sourcing this
# from the concurrency calculation handles hybrid layouts correctly: SWA /
# chunked-local groups have a per-request block count that's capped by
# their window, so a naive `num_blocks // num_groups * block_size` formula
# underestimates capacity for these models. DCP/PCP sharding is already
# accounted for in each spec's `max_memory_usage_bytes`.
num_tokens = int(max_concurrency * max_model_len)

logger.info_once("GPU KV cache size: %s tokens", f"{num_tokens:,}")
logger.info_once(
"Maximum concurrency for %s tokens per request: %.2fx",
f"{max_model_len:,}",
max_concurrency,
)
return int(max_concurrency * max_model_len), max_concurrency


def _max_memory_usage_bytes_from_groups(
Expand Down Expand Up @@ -2085,7 +2066,21 @@ def get_kv_cache_configs(
tensor.size = tensor.size // num_blocks_old * min_num_blocks

if len(kv_cache_config.kv_cache_groups) > 0:
_report_kv_cache_config(vllm_config, kv_cache_config)
max_model_len = vllm_config.model_config.max_model_len
# GPU KV cache size in tokens = max_concurrency * max_model_len:
# the total tokens of context the pool can hold at peak
# utilization. Sourcing this from the concurrency calculation
# handles hybrid layouts correctly.
num_tokens, max_concurrency = get_kv_cache_capacity(
vllm_config, kv_cache_config
)

logger.info_once("GPU KV cache size: %s tokens", f"{num_tokens:,}")
logger.info_once(
"Maximum concurrency for %s tokens per request: %.2fx",
f"{max_model_len:,}",
max_concurrency,
)

return kv_cache_configs

Expand Down
3 changes: 3 additions & 0 deletions vllm/v1/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ class EngineCoreReadyResponse:
dp_stats_address: str | None
dtype: str
vllm_version: str
# KV cache capacity (None for encoder-only/attention-free models).
kv_cache_size_tokens: int | None = None
kv_cache_max_concurrency: float | None = None


class EngineCoreRequest(
Expand Down
12 changes: 12 additions & 0 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from vllm.v1.core.kv_cache_utils import (
BlockHash,
generate_scheduler_kv_cache_config,
get_kv_cache_capacity,
get_kv_cache_configs,
get_request_block_hasher,
init_none_hash,
Expand Down Expand Up @@ -286,6 +287,11 @@ def _initialize_kv_caches(self, vllm_config: VllmConfig) -> KVCacheConfig:
vllm_config.cache_config.block_size = min(
g.kv_cache_spec.block_size for g in kv_cache_groups
)
num_tokens, max_concurrency = get_kv_cache_capacity(
vllm_config, scheduler_kv_cache_config
)
vllm_config.cache_config.kv_cache_size_tokens = num_tokens
vllm_config.cache_config.kv_cache_max_concurrency = max_concurrency

vllm_config.validate_block_size()

Expand Down Expand Up @@ -1494,6 +1500,12 @@ def process_input_sockets(
dp_stats_address=self.frontend_stats_publish_address,
dtype=str(self.vllm_config.model_config.dtype).removeprefix("torch."),
vllm_version=VLLM_VERSION,
kv_cache_size_tokens=(
self.vllm_config.cache_config.kv_cache_size_tokens
),
kv_cache_max_concurrency=(
self.vllm_config.cache_config.kv_cache_max_concurrency
),
)
ready_payload = msgspec.msgpack.encode(ready_response)
for input_socket in input_sockets:
Expand Down
16 changes: 14 additions & 2 deletions vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,14 +720,26 @@ def _apply_ready_response(self, payload: bytes) -> None:
)

# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
# engine core process. Sum num_gpu_blocks from all engines in DP case.
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks or 0
num_gpu_blocks += response.num_gpu_blocks
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks

# Sync block_size: may be enlarged by _align_hybrid_block_size in the
# worker for hybrid Mamba models.
vllm_config.cache_config.block_size = response.block_size
cache_config = vllm_config.cache_config
cache_config.block_size = response.block_size
# Keep these as per-engine cache_config_info values; do not sum across DP.
cache_config.kv_cache_size_tokens = (
getattr(cache_config, "kv_cache_size_tokens", None)
if getattr(cache_config, "kv_cache_size_tokens", None) is not None
else response.kv_cache_size_tokens
)
cache_config.kv_cache_max_concurrency = (
getattr(cache_config, "kv_cache_max_concurrency", None)
if getattr(cache_config, "kv_cache_max_concurrency", None) is not None
else response.kv_cache_max_concurrency
)

# In external DP LB mode, the coordinator address that the
# front-end procs connect to is obtained by each engine via it's
Expand Down
Loading