-
-
Notifications
You must be signed in to change notification settings - Fork 15.1k
refactor(attention): add default get_kv_cache_stride_order implementation #34742
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -108,10 +108,8 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: | |
| kv_cache_shape = attn_backend.get_kv_cache_shape( | ||
| num_blocks, self.block_size, self.num_kv_heads, self.head_size | ||
| ) | ||
| try: | ||
| kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() | ||
| except (AttributeError, NotImplementedError): | ||
| kv_cache_stride_order = tuple(range(len(kv_cache_shape))) | ||
| kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() | ||
| assert len(kv_cache_stride_order) == len(kv_cache_shape) | ||
|
Comment on lines
+111
to
+112
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) | ||
| inv_order = [ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -342,12 +342,9 @@ def __post_init__(self): | |
| # prepend layers dimension | ||
| _MOCK_NUM_LAYERS = 80 | ||
| kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape | ||
| try: | ||
| kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order( | ||
| include_num_layers_dimension=self._cross_layers_blocks | ||
| ) | ||
| except (AttributeError, NotImplementedError): | ||
| kv_cache_stride_order = tuple(range(len(self.tensor_shape))) | ||
| kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order( | ||
| include_num_layers_dimension=self._cross_layers_blocks | ||
| ) | ||
|
Comment on lines
+345
to
+347
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| # In case of cross layers permute kv_cache_shape according to | ||
| # stride_order to retrieve physical position of block_size | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -98,8 +98,10 @@ def get_kv_cache_stride_order( | |
| ordering of dimensions is | ||
| [num_blocks, num_heads, 2, block_size, head_size]. | ||
|
|
||
| If this function is unimplemented / raises NotImplementedError, | ||
| the physical layout of the KV cache will match the logical shape. | ||
| The default implementation returns the identity permutation (physical | ||
| layout matches the logical shape) for the standard 5-dimensional | ||
| KV cache shape. Backends that need a custom memory layout should | ||
| override this method. | ||
|
Comment on lines
+101
to
+104
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| Args: | ||
| include_num_layers_dimension: if True, includes an additional | ||
|
|
@@ -115,7 +117,11 @@ def get_kv_cache_stride_order( | |
| Returns: | ||
| A tuple of ints which is a permutation of range(len(shape)). | ||
| """ | ||
| raise NotImplementedError | ||
| # Standard KV cache has 5 dimensions. | ||
| num_dims = 5 | ||
| if include_num_layers_dimension: | ||
| num_dims += 1 | ||
| return tuple(range(num_dims)) | ||
|
Comment on lines
+120
to
+124
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| @classmethod | ||
| def full_cls_name(cls) -> tuple[str, str]: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -259,13 +259,10 @@ def __init__( | |
| assert gpu_shape[0] == 2 | ||
| split_k_and_v = True | ||
|
|
||
| try: | ||
| kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( | ||
| include_num_layers_dimension=has_layers_dim | ||
| ) | ||
| assert len(kv_cache_stride_order) == len(gpu_shape) | ||
| except (AttributeError, NotImplementedError): | ||
| kv_cache_stride_order = tuple(range(len(gpu_shape))) | ||
| kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( | ||
| include_num_layers_dimension=has_layers_dim | ||
| ) | ||
| assert len(kv_cache_stride_order) == len(gpu_shape) | ||
|
Comment on lines
+262
to
+265
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| # permute test_shape according to stride_order | ||
| test_shape = tuple(test_shape[i] for i in kv_cache_stride_order) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -99,12 +99,8 @@ def _reshape_kv_cache( | |
| kv_cache_spec.head_size, | ||
| ) | ||
|
|
||
| # FIXME(woosuk): Add kv_cache_stride_order to all attention backends. | ||
| try: | ||
| kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() | ||
| assert len(kv_cache_stride_order) == len(kv_cache_shape) | ||
| except (AttributeError, NotImplementedError): | ||
| kv_cache_stride_order = tuple(range(len(kv_cache_shape))) | ||
| kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() | ||
| assert len(kv_cache_stride_order) == len(kv_cache_shape) | ||
|
Comment on lines
+102
to
+103
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) | ||
| inv_order = [ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5877,11 +5877,8 @@ def _reshape_kv_cache_tensors( | |
| cache_dtype_str=self.cache_config.cache_dtype, | ||
| ) | ||
| dtype = kv_cache_spec.dtype | ||
| try: | ||
| kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() | ||
| assert len(kv_cache_stride_order) == len(kv_cache_shape) | ||
| except (AttributeError, NotImplementedError): | ||
| kv_cache_stride_order = tuple(range(len(kv_cache_shape))) | ||
| kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() | ||
| assert len(kv_cache_stride_order) == len(kv_cache_shape) | ||
|
Comment on lines
+5880
to
+5881
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| # The allocation respects the backend-defined stride order | ||
| # to ensure the semantic remains consistent for each | ||
| # backend. We first obtain the generic kv cache shape and | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -127,7 +127,7 @@ def use_uniform_kv_cache( | |
| have the same page size. | ||
| 2. A KV connector is configured, and the KV connector instance prefers | ||
| to use this layout (prefer_cross_layer_blocks() returns True) | ||
| 2. The flash attention backend supports this layout | ||
| 3. The attention backend supports this layout | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| (get_kv_cache_stride_order(True) includes a placement for a | ||
| num_layers dimension) | ||
|
|
||
|
|
@@ -166,12 +166,9 @@ def use_uniform_kv_cache( | |
| cache_dtype_str=cache_dtype, | ||
| ) | ||
|
|
||
| try: | ||
| kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( | ||
| include_num_layers_dimension=True | ||
| ) | ||
| except (AttributeError, NotImplementedError): | ||
| return False | ||
| kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( | ||
| include_num_layers_dimension=True | ||
| ) | ||
|
Comment on lines
+169
to
+171
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| # check that attention backend include a layers dimension | ||
| return len(kv_cache_stride_order) == len(kv_cache_shape) + 1 | ||
|
|
@@ -236,13 +233,10 @@ def allocate_uniform_kv_caches( | |
| # prepend a num_layers dimension into the shape | ||
| kv_cache_shape = (num_layers,) + kv_cache_shape | ||
|
|
||
| try: | ||
| kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( | ||
| include_num_layers_dimension=True | ||
| ) | ||
| assert len(kv_cache_stride_order) == len(kv_cache_shape) | ||
| except (AttributeError, NotImplementedError): | ||
| kv_cache_stride_order = tuple(range(len(kv_cache_shape))) | ||
| kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( | ||
| include_num_layers_dimension=True | ||
| ) | ||
| assert len(kv_cache_stride_order) == len(kv_cache_shape) | ||
|
Comment on lines
+236
to
+239
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The removal of the `try/except` block is correct given the new default implementation in `AttentionBackend.get_kv_cache_stride_order`. This simplifies the code and removes unnecessary error handling.