Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions benchmarks/attention_benchmarks/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,11 +333,8 @@ def _create_kv_cache(
)

# Get the stride order for custom memory layout
try:
stride_order = backend_class.get_kv_cache_stride_order()
assert len(stride_order) == len(cache_shape)
except (AttributeError, NotImplementedError):
stride_order = tuple(range(len(cache_shape)))
stride_order = backend_class.get_kv_cache_stride_order()
assert len(stride_order) == len(cache_shape)
Comment on lines +336 to +337
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The removal of the try/except block is correct given the new default implementation in AttentionBackend.get_kv_cache_stride_order. This simplifies the code and removes unnecessary error handling.


# Permute shape to physical layout order
physical_shape = tuple(cache_shape[i] for i in stride_order)
Expand Down
6 changes: 2 additions & 4 deletions tests/compile/passes/test_fusion_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,8 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size
)
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape)
Comment on lines +111 to +112
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The removal of the try/except block is appropriate as get_kv_cache_stride_order now provides a default implementation, ensuring that an AttributeError or NotImplementedError will no longer be raised. This improves code readability and reduces boilerplate.


kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
inv_order = [
Expand Down
9 changes: 3 additions & 6 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,9 @@ def __post_init__(self):
# prepend layers dimension
_MOCK_NUM_LAYERS = 80
kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape
try:
kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=self._cross_layers_blocks
)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(self.tensor_shape)))
kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=self._cross_layers_blocks
)
Comment on lines +345 to +347
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The try/except block is no longer necessary due to the default implementation of get_kv_cache_stride_order. This change correctly streamlines the code.


# In case of cross layers permute kv_cache_shape according to
# stride_order to retrieve physical position of block_size
Expand Down
12 changes: 9 additions & 3 deletions vllm/v1/attention/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,10 @@ def get_kv_cache_stride_order(
ordering of dimensions is
[num_blocks, num_heads, 2, block_size, head_size].

If this function is unimplemented / raises NotImplementedError,
the physical layout of the KV cache will match the logical shape.
The default implementation returns the identity permutation (physical
layout matches the logical shape) for the standard 5-dimensional
KV cache shape. Backends that need a custom memory layout should
override this method.
Comment on lines +101 to +104
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The updated docstring accurately reflects the new default behavior of get_kv_cache_stride_order, which now returns the identity permutation instead of raising NotImplementedError. This clarifies the function's contract for future developers.


Args:
include_num_layers_dimension: if True, includes an additional
Expand All @@ -115,7 +117,11 @@ def get_kv_cache_stride_order(
Returns:
A tuple of ints which is a permutation of range(len(shape)).
"""
raise NotImplementedError
# Standard KV cache has 5 dimensions.
num_dims = 5
if include_num_layers_dimension:
num_dims += 1
return tuple(range(num_dims))
Comment on lines +120 to +124
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This new default implementation for get_kv_cache_stride_order is a significant improvement: it centralizes the default behavior and removes boilerplate try/except handling from multiple call sites. The include_num_layers_dimension handling is correct for the standard layout. Note, however, that the default hard-codes a 5-dimensional cache shape — a backend whose get_kv_cache_shape returns a different rank and does not override this method will now fail the length asserts at the call sites, so such backends must provide their own override.


@classmethod
def full_cls_name(cls) -> tuple[str, str]:
Expand Down
11 changes: 4 additions & 7 deletions vllm/v1/kv_offload/worker/cpu_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,13 +259,10 @@ def __init__(
assert gpu_shape[0] == 2
split_k_and_v = True

try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=has_layers_dim
)
assert len(kv_cache_stride_order) == len(gpu_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(gpu_shape)))
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=has_layers_dim
)
assert len(kv_cache_stride_order) == len(gpu_shape)
Comment on lines +262 to +265
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The removal of the try/except block is correct, as the get_kv_cache_stride_order method now provides a default implementation, eliminating the possibility of AttributeError or NotImplementedError. This simplifies the code and improves maintainability.


# permute test_shape according to stride_order
test_shape = tuple(test_shape[i] for i in kv_cache_stride_order)
Expand Down
8 changes: 2 additions & 6 deletions vllm/v1/worker/gpu/attn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,8 @@ def _reshape_kv_cache(
kv_cache_spec.head_size,
)

# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape)
Comment on lines +102 to +103
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The removal of the FIXME comment and the try/except block is appropriate. The new default implementation of get_kv_cache_stride_order resolves the issue that the FIXME comment was addressing, leading to cleaner and more reliable code.


kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
inv_order = [
Expand Down
7 changes: 2 additions & 5 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5877,11 +5877,8 @@ def _reshape_kv_cache_tensors(
cache_dtype_str=self.cache_config.cache_dtype,
)
dtype = kv_cache_spec.dtype
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape)
Comment on lines +5880 to +5881
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The removal of the try/except block is a correct simplification. With get_kv_cache_stride_order now having a default implementation, these error handling constructs are no longer necessary, making the code cleaner and more direct.

# The allocation respects the backend-defined stride order
# to ensure the semantic remains consistent for each
# backend. We first obtain the generic kv cache shape and
Expand Down
22 changes: 8 additions & 14 deletions vllm/v1/worker/kv_connector_model_runner_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def use_uniform_kv_cache(
have the same page size.
2. A KV connector is configured, and the KV connector instance prefers
to use this layout (prefer_cross_layer_blocks() returns True)
2. The flash attention backend supports this layout
3. The attention backend supports this layout
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The change from 'flash attention backend' to 'attention backend' is a good generalization, as the get_kv_cache_stride_order method is now part of the base AttentionBackend and not specific to FlashAttention. It also fixes the duplicated list numbering (two consecutive items were both labeled '2'). Both changes improve the accuracy of the docstring.

(get_kv_cache_stride_order(True) includes a placement for a
num_layers dimension)

Expand Down Expand Up @@ -166,12 +166,9 @@ def use_uniform_kv_cache(
cache_dtype_str=cache_dtype,
)

try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=True
)
except (AttributeError, NotImplementedError):
return False
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=True
)
Comment on lines +169 to +171
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The removal of the try/except block is not purely cosmetic here: previously, a backend that did not implement get_kv_cache_stride_order caused this function to return False, whereas the new base-class default returns a 6-element identity permutation when include_num_layers_dimension=True, so the subsequent length check (len(stride_order) == len(shape) + 1) now passes for standard-shape backends. Please confirm that opting such backends into the uniform cross-layer KV cache layout by default is the intended behavior.


# check that attention backend include a layers dimension
return len(kv_cache_stride_order) == len(kv_cache_shape) + 1
Expand Down Expand Up @@ -236,13 +233,10 @@ def allocate_uniform_kv_caches(
# prepend a num_layers dimension into the shape
kv_cache_shape = (num_layers,) + kv_cache_shape

try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=True
)
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=True
)
assert len(kv_cache_stride_order) == len(kv_cache_shape)
Comment on lines +236 to +239
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The removal of the try/except block is appropriate. The get_kv_cache_stride_order method now provides a default implementation, so the error handling for AttributeError or NotImplementedError is no longer needed. This makes the code more concise.


kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)

Expand Down
Loading