
refactor(attention): add default get_kv_cache_stride_order implementation #34742

Open

timon0305 wants to merge 1 commit into vllm-project:main from timon0305:refactor-kv-cache-stride-order

Conversation

@timon0305

Purpose

Provide a default implementation for AttentionBackend.get_kv_cache_stride_order() that returns the identity permutation, eliminating the need for try/except (AttributeError, NotImplementedError) fallbacks at every call site.

Previously, the base method raised NotImplementedError, forcing 7 call sites across the codebase to wrap calls in try/except blocks with manual fallback to tuple(range(len(shape))). This was flagged in #33951 and the related FIXME comment in attn_utils.py.

The default returns the identity permutation for the standard 5-dimensional KV cache shape (or 6 with layers dimension). Backends that need custom memory layouts (FlashAttention, FlashInfer, Triton, MLA, etc.) already override this method and are unaffected.
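The call-site simplification can be sketched as follows. This is a minimal standalone sketch: the backend classes and helper functions here are illustrative stand-ins, not the actual vllm classes or call sites.

```python
# Sketch of the call-site pattern this PR removes. OldBackend/NewBackend
# are toy stand-ins for the vllm AttentionBackend base class.
class OldBackend:
    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        raise NotImplementedError  # old base-class behavior

def stride_order_before(backend, shape):
    # Fallback previously repeated at 7 call sites.
    try:
        return backend.get_kv_cache_stride_order()
    except (AttributeError, NotImplementedError):
        return tuple(range(len(shape)))

class NewBackend:
    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        # New base-class default: identity permutation.
        return tuple(range(5))

def stride_order_after(backend, shape):
    # With the default on the base class, the call always succeeds.
    order = backend.get_kv_cache_stride_order()
    assert len(order) == len(shape)
    return order

shape = (2, 10, 8, 16, 64)
assert stride_order_before(OldBackend, shape) == (0, 1, 2, 3, 4)
assert stride_order_after(NewBackend, shape) == (0, 1, 2, 3, 4)
```

Both paths produce the same identity permutation; the "after" version simply moves the fallback into the base class.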

Fixes #33951

Changes

  • vllm/v1/attention/backend.py: Replace raise NotImplementedError with a default identity permutation
  • vllm/v1/worker/gpu_model_runner.py: Remove try/except fallback
  • vllm/v1/worker/kv_connector_model_runner_mixin.py: Remove 2 try/except fallbacks, fix numbering in docstring
  • vllm/v1/worker/gpu/attn_utils.py: Remove try/except fallback and FIXME comment
  • vllm/v1/kv_offload/worker/cpu_gpu.py: Remove try/except fallback
  • vllm/distributed/kv_transfer/kv_connector/utils.py: Remove try/except fallback
  • benchmarks/attention_benchmarks/runner.py: Remove try/except fallback
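Backends with a custom memory layout keep overriding the method exactly as before. A hedged sketch (toy classes, not the real vllm backend hierarchy) of how such an override coexists with the new base default, mirroring the Triton ordering shown in the test results:

```python
# Toy sketch of the override pattern; these classes are illustrative,
# not the actual vllm AttentionBackend classes.
class BaseBackend:
    @staticmethod
    def get_kv_cache_stride_order(include_num_layers_dimension: bool = False):
        # Default: identity permutation over the standard 5 dims
        # (6 when a layers dimension is included).
        num_dims = 6 if include_num_layers_dimension else 5
        return tuple(range(num_dims))

class CustomLayoutBackend(BaseBackend):
    @staticmethod
    def get_kv_cache_stride_order(include_num_layers_dimension: bool = False):
        if include_num_layers_dimension:
            # Swap the two leading dims, like the Triton ordering in
            # the test results.
            return (1, 0, 2, 3, 4, 5)
        return (0, 1, 2, 3, 4)

assert BaseBackend.get_kv_cache_stride_order() == (0, 1, 2, 3, 4)
assert CustomLayoutBackend.get_kv_cache_stride_order(True) == (1, 0, 2, 3, 4, 5)
```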

Test Plan

python -c "
from vllm.v1.attention.backend import AttentionBackend
assert AttentionBackend.get_kv_cache_stride_order() == (0, 1, 2, 3, 4)
assert AttentionBackend.get_kv_cache_stride_order(True) == (0, 1, 2, 3, 4, 5)
from vllm.v1.attention.backends.cpu_attn import CPUAttentionBackend
shape = CPUAttentionBackend.get_kv_cache_shape(10, 16, 8, 64)
order = CPUAttentionBackend.get_kv_cache_stride_order()
assert len(shape) == len(order)
print('All assertions passed')
"

Test Result

Default (no layers): (0, 1, 2, 3, 4)
Default (with layers): (0, 1, 2, 3, 4, 5)
Triton (no layers): (0, 1, 2, 3, 4)
Triton (with layers): (1, 0, 2, 3, 4, 5)
CPU (no layers): (0, 1, 2, 3, 4)
CPU shape: (2, 10, 8, 16, 64), ndim=5
All tests passed!
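To illustrate what a stride order means in practice, here is a small hypothetical helper (not part of vllm) that maps a logical KV cache shape to the physical allocation order implied by a given stride order:

```python
def physical_shape(logical_shape, stride_order):
    # Hypothetical helper: dimension stride_order[0] is outermost
    # (varies slowest) in the physical allocation.
    return tuple(logical_shape[d] for d in stride_order)

# With the identity permutation, physical and logical layouts coincide.
logical = (2, 10, 8, 16, 64)
assert physical_shape(logical, (0, 1, 2, 3, 4)) == logical

# A non-identity order like Triton's (1, 0, 2, 3, 4, 5) swaps the two
# leading dimensions in memory.
logical6 = (4, 2, 10, 8, 16, 64)
print(physical_shape(logical6, (1, 0, 2, 3, 4, 5)))  # (2, 4, 10, 8, 16, 64)
```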

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update.

refactor(attention): add default get_kv_cache_stride_order implementation

Fixes vllm-project#33951

Signed-off-by: timon0305 <timon0305@outlook.com>

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request refactors the AttentionBackend.get_kv_cache_stride_order method to provide a default implementation, eliminating the need for try/except blocks at call sites. This change simplifies the codebase and improves clarity by centralizing the default behavior. The modifications correctly remove redundant error handling and align with the new default implementation. The changes are well-tested and address the identified issue effectively.

Comment on lines +336 to +337
stride_order = backend_class.get_kv_cache_stride_order()
assert len(stride_order) == len(cache_shape)

Severity: high

The removal of the try/except block is correct given the new default implementation in AttentionBackend.get_kv_cache_stride_order. This simplifies the code and removes unnecessary error handling.

Comment on lines +111 to +112
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape)

Severity: high

The removal of the try/except block is appropriate as get_kv_cache_stride_order now provides a default implementation, ensuring that an AttributeError or NotImplementedError will no longer be raised. This improves code readability and reduces boilerplate.

Comment on lines +345 to +347
kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
    include_num_layers_dimension=self._cross_layers_blocks
)

Severity: high

The try/except block is no longer necessary due to the default implementation of get_kv_cache_stride_order. This change correctly streamlines the code.

Comment on lines +101 to +104
The default implementation returns the identity permutation (physical
layout matches the logical shape) for the standard 5-dimensional
KV cache shape. Backends that need a custom memory layout should
override this method.

Severity: high

The updated docstring accurately reflects the new default behavior of get_kv_cache_stride_order, which now returns the identity permutation instead of raising NotImplementedError. This clarifies the function's contract for future developers.

Comment on lines +120 to +124
# Standard KV cache has 5 dimensions.
num_dims = 5
if include_num_layers_dimension:
    num_dims += 1
return tuple(range(num_dims))

Severity: high

This new default implementation for get_kv_cache_stride_order is a significant improvement. It centralizes the default behavior, reducing boilerplate code across multiple call sites and making the API more robust. The logic for handling include_num_layers_dimension is also correctly implemented.

Comment on lines +102 to +103
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape)

Severity: high

The removal of the FIXME comment and the try/except block is appropriate. The new default implementation of get_kv_cache_stride_order resolves the issue that the FIXME comment was addressing, leading to cleaner and more reliable code.

Comment on lines +5880 to +5881
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape)

Severity: high

The removal of the try/except block is a correct simplification. With get_kv_cache_stride_order now having a default implementation, these error handling constructs are no longer necessary, making the code cleaner and more direct.

 2. A KV connector is configured, and the KV connector instance prefers
    to use this layout (prefer_cross_layer_blocks() returns True)
-2. The flash attention backend supports this layout
+3. The attention backend supports this layout

Severity: high

The change from 'flash attention backend' to 'attention backend' is a good generalization, as the get_kv_cache_stride_order method is now part of the base AttentionBackend and not specific to FlashAttention. This improves the accuracy of the docstring.

Comment on lines +169 to +171
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
    include_num_layers_dimension=True
)

Severity: high

The removal of the try/except block is correct. The get_kv_cache_stride_order method now guarantees a return value, making the error handling redundant and simplifying the logic.

Comment on lines +236 to +239
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
    include_num_layers_dimension=True
)
assert len(kv_cache_stride_order) == len(kv_cache_shape)

Severity: high

The removal of the try/except block is appropriate. The get_kv_cache_stride_order method now provides a default implementation, so the error handling for AttributeError or NotImplementedError is no longer needed. This makes the code more concise.


Labels

kv-connector · performance (Performance-related issues) · v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature]: Implement get_kv_cache_stride_order for all classes

1 participant