diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index f73431fd7b2f..b184f6574130 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -322,7 +322,7 @@ def __post_init__(self): # Figure out whether the first dimension of the cache is K/V # or num_blocks. This is used to register the memory regions correctly. kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=4, head_size=1 + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 ) # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D], # we just mock num_blocks to 1 for the dimension check below.