diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 8e5bb11e4da..34089a67b3b 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -115,9 +115,12 @@ def _reshape_kv_cache( ) -> dict[str, torch.Tensor]: kv_caches: dict[str, torch.Tensor] = {} for kv_cache_group_spec in kv_cache_config.kv_cache_groups: - kv_cache_spec = kv_cache_group_spec.kv_cache_spec - assert isinstance(kv_cache_spec, AttentionSpec) for layer_name in kv_cache_group_spec.layer_names: + kv_cache_spec = kv_cache_group_spec.kv_cache_spec + if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs): + kv_cache_spec = kv_cache_spec.kv_cache_specs[layer_name] + assert isinstance(kv_cache_spec, AttentionSpec) + raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes