diff --git a/benchmarks/attention_benchmarks/runner.py b/benchmarks/attention_benchmarks/runner.py index 6457a599ab91..220c27c7674d 100644 --- a/benchmarks/attention_benchmarks/runner.py +++ b/benchmarks/attention_benchmarks/runner.py @@ -333,11 +333,8 @@ def _create_kv_cache( ) # Get the stride order for custom memory layout - try: - stride_order = backend_class.get_kv_cache_stride_order() - assert len(stride_order) == len(cache_shape) - except (AttributeError, NotImplementedError): - stride_order = tuple(range(len(cache_shape))) + stride_order = backend_class.get_kv_cache_stride_order() + assert len(stride_order) == len(cache_shape) # Permute shape to physical layout order physical_shape = tuple(cache_shape[i] for i in stride_order) diff --git a/tests/compile/passes/test_fusion_attn.py b/tests/compile/passes/test_fusion_attn.py index ffa01563ef98..3b64c36df8dc 100644 --- a/tests/compile/passes/test_fusion_attn.py +++ b/tests/compile/passes/test_fusion_attn.py @@ -108,10 +108,8 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: kv_cache_shape = attn_backend.get_kv_cache_shape( num_blocks, self.block_size, self.num_kv_heads, self.head_size ) - try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(kv_cache_shape))) + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + assert len(kv_cache_stride_order) == len(kv_cache_shape) kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) inv_order = [ diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index f9367da73710..897637b1cdf8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -342,12 +342,9 @@ def __post_init__(self): # prepend layers dimension _MOCK_NUM_LAYERS = 80 kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape - try: - kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order( - include_num_layers_dimension=self._cross_layers_blocks - ) - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(self.tensor_shape))) + kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order( + include_num_layers_dimension=self._cross_layers_blocks + ) # In case of cross layers permute kv_cache_shape according to # stride_order to retrieve physical position of block_size diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 9c004d7724dd..3000ff9cd433 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -98,8 +98,10 @@ def get_kv_cache_stride_order( ordering of dimensions is [num_blocks, num_heads, 2, block_size, head_size]. - If this function is unimplemented / raises NotImplementedError, - the physical layout of the KV cache will match the logical shape. + The default implementation returns the identity permutation (physical + layout matches the logical shape) for the standard 5-dimensional + KV cache shape. Backends that need a custom memory layout should + override this method. Args: include_num_layers_dimension: if True, includes an additional @@ -115,7 +117,11 @@ def get_kv_cache_stride_order( Returns: A tuple of ints which is a permutation of range(len(shape)). """ - raise NotImplementedError + # Standard KV cache has 5 dimensions. + num_dims = 5 + if include_num_layers_dimension: + num_dims += 1 + return tuple(range(num_dims)) @classmethod def full_cls_name(cls) -> tuple[str, str]: diff --git a/vllm/v1/kv_offload/worker/cpu_gpu.py b/vllm/v1/kv_offload/worker/cpu_gpu.py index a5abae51ef03..702e3fdd002d 100644 --- a/vllm/v1/kv_offload/worker/cpu_gpu.py +++ b/vllm/v1/kv_offload/worker/cpu_gpu.py @@ -259,13 +259,10 @@ def __init__( assert gpu_shape[0] == 2 split_k_and_v = True - try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( - include_num_layers_dimension=has_layers_dim - ) - assert len(kv_cache_stride_order) == len(gpu_shape) - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(gpu_shape))) + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( + include_num_layers_dimension=has_layers_dim + ) + assert len(kv_cache_stride_order) == len(gpu_shape) # permute test_shape according to stride_order test_shape = tuple(test_shape[i] for i in kv_cache_stride_order) diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 8a08fba1e44a..a3fb3f3678e2 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -99,12 +99,8 @@ def _reshape_kv_cache( kv_cache_spec.head_size, ) - # FIXME(woosuk): Add kv_cache_stride_order to all attention backends. - try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() - assert len(kv_cache_stride_order) == len(kv_cache_shape) - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(kv_cache_shape))) + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + assert len(kv_cache_stride_order) == len(kv_cache_shape) kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) inv_order = [ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 41ec062305b5..1ea68c386c56 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5877,11 +5877,8 @@ def _reshape_kv_cache_tensors( cache_dtype_str=self.cache_config.cache_dtype, ) dtype = kv_cache_spec.dtype - try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() - assert len(kv_cache_stride_order) == len(kv_cache_shape) - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(kv_cache_shape))) + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + assert len(kv_cache_stride_order) == len(kv_cache_shape) # The allocation respects the backend-defined stride order # to ensure the semantic remains consistent for each # backend. We first obtain the generic kv cache shape and diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index 0556c3e6e41c..c930256d6436 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -127,7 +127,7 @@ def use_uniform_kv_cache( have the same page size. 2. A KV connector is configured, and the KV connector instance prefers to use this layout (prefer_cross_layer_blocks() returns True) - 2. The flash attention backend supports this layout + 3. The attention backend supports this layout (get_kv_cache_stride_order(True) includes a placement for a num_layers dimension) @@ -166,12 +166,9 @@ def use_uniform_kv_cache( cache_dtype_str=cache_dtype, ) - try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( - include_num_layers_dimension=True - ) - except (AttributeError, NotImplementedError): - return False + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( + include_num_layers_dimension=True + ) # check that attention backend include a layers dimension return len(kv_cache_stride_order) == len(kv_cache_shape) + 1 @@ -236,13 +233,10 @@ def allocate_uniform_kv_caches( # prepend a num_layers dimension into the shape kv_cache_shape = (num_layers,) + kv_cache_shape - try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( - include_num_layers_dimension=True - ) - assert len(kv_cache_stride_order) == len(kv_cache_shape) - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(kv_cache_shape))) + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( + include_num_layers_dimension=True + ) + assert len(kv_cache_stride_order) == len(kv_cache_shape) kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)