[KVConnector][Core] Support cross-layer KV blocks #27743
```diff
@@ -38,7 +38,7 @@
 import enum
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Iterable
-from typing import TYPE_CHECKING, Any, Literal, Optional
+from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional

 import torch
```
|
```diff
@@ -47,7 +47,7 @@
 from vllm.v1.outputs import KVConnectorOutput

 if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionMetadata
+    from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
     from vllm.config import VllmConfig
     from vllm.distributed.kv_events import KVCacheEvent
     from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
```
|
```diff
@@ -142,6 +142,18 @@ class KVConnectorMetadata(ABC):  # noqa: B024


 class KVConnectorBase_V1(ABC):
     """
     Base class for KV connectors.
+
+    Attributes:
+        prefer_cross_layer_blocks (bool): Indicates whether this connector
+            prefers KV blocks that hold KV data for all layers (for speeding
+            up KV data transfers).
+            Defaults to False.
     """
+
+    prefer_cross_layer_blocks: ClassVar[bool] = False
+
     def __init__(
         self,
         vllm_config: "VllmConfig",
```
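As a hedged sketch of why the flag is declared as a `ClassVar` (the class names below mirror the diff, but the bodies are illustrative stand-ins, not the actual vLLM implementation): a class-level attribute can be inspected without instantiating the connector, and connectors that do not care about the KV layout simply inherit the `False` default.

```python
from abc import ABC
from typing import ClassVar


# Simplified stand-in for the base class in the diff above.
class KVConnectorBase_V1(ABC):
    """Base class for KV connectors (sketch)."""

    # Class-level default: connectors that ignore the KV layout
    # inherit False and nothing changes for them.
    prefer_cross_layer_blocks: ClassVar[bool] = False


# Hypothetical connector that transfers whole blocks for all layers
# at once; it opts in by overriding the class attribute.
class CrossLayerConnector(KVConnectorBase_V1):
    prefer_cross_layer_blocks = True


# Because the flag is a ClassVar, the model runner can read it on the
# class itself, before any connector instance exists.
print(KVConnectorBase_V1.prefer_cross_layer_blocks)   # False
print(CrossLayerConnector.prefer_cross_layer_blocks)  # True
```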
|
|
```diff
@@ -226,6 +238,23 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """
         return

+    def register_cross_layers_kv_cache(
+        self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
+    ):
+        """
+        Initialize with a single KV cache tensor used by all layers.
+        The first dimension should be num_layers.
+        This function will only be called for models with uniform layers,
+        and only if prefer_cross_layer_blocks is set to True.
+        Only one of the functions
+        {register_kv_caches, register_cross_layers_kv_cache} will be called.
+
+        Args:
+            kv_cache: a cross-layers kv cache tensor
+            attn_backend: The attention backend that corresponds to all layers
+        """
+        return
+
     def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
         """
         Set the xPU-specific ops for copying KV between host and device.
```

Comment on lines +241 to +243

Collaborator: Why do we need to pass […]? If so, can we also merge this function with the existing […]?

Author: The single cross-layers kv-cache tensor always matches a single attention backend. This new function may only be called if a connector set […]

Collaborator: Can't agree. This layout can be extended to full + swa / full + mamba cases in the future. And I don't want to change the connector API when we start to work on that.

Collaborator: After a second thought, I feel more comfortable with the current API when supporting multiple attention backends. To use cross-layer KV blocks with multiple attention backends, we need to let all backends agree on the underlying stride, and it's OK to pass one attention backend that can represent all of them. The logic for this agreement is needed for all connectors, so I prefer to do it in the model runner instead of inside each connector.

Collaborator: Dumb question: if the goal is to have the layer name, why don't we just pass the layer name to the connector? Or is that impossible?

Author: Not sure I understand your question.
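To illustrate why a single cross-layer tensor can speed up transfers (a pure-Python sketch; the sizes and the helper name `block_offsets` are assumptions, not vLLM code): with `num_layers` as the outermost dimension, every copy of one logical KV block across all layers can be described by one base address plus a fixed per-layer stride, so a transfer engine needs only a single memory registration instead of one per layer.

```python
# Hypothetical sizes, for illustration only.
num_layers, num_blocks, block_bytes = 4, 8, 256


def block_offsets(block_id: int) -> list[int]:
    """Byte offsets of one logical KV block, one entry per layer.

    With the cross-layer layout from the diff ("the first dimension
    should be num_layers"), all offsets follow from one base pointer
    and a constant layer stride.
    """
    layer_stride = num_blocks * block_bytes  # bytes per layer slab
    return [
        layer * layer_stride + block_id * block_bytes
        for layer in range(num_layers)
    ]


print(block_offsets(3))  # [768, 2816, 4864, 6912]
```

The per-layer spacing is constant (`2048` bytes here), which is what lets one strided descriptor cover the whole block; with separate per-layer tensors there is no such guarantee.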
Quick question: what will happen if the connector implementation sets it to true? Will that impact vLLM's initialization process?

Not necessarily. The conditions that affect whether this cross-layers layout will be used are documented in the `use_uniform_kv_cache` function:
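The thread truncates before quoting that function. Based only on the docstring in the diff ("will only be called for models with uniform layers, and only if prefer_cross_layer_blocks is set to True"), the gating might look roughly like the following; the signature and parameter names are assumptions, not the actual vLLM code:

```python
def use_uniform_kv_cache(
    layer_specs: list[str], connector_prefers_cross_layer: bool
) -> bool:
    """Sketch of the gating described in the diff's docstring.

    The cross-layer layout is only used when every layer shares the
    same KV-cache spec AND the connector opted in via
    prefer_cross_layer_blocks. Names here are illustrative.
    """
    uniform = len(set(layer_specs)) <= 1
    return uniform and connector_prefers_cross_layer


print(use_uniform_kv_cache(["full", "full"], True))     # True
print(use_uniform_kv_cache(["full", "sliding"], True))  # False
print(use_uniform_kv_cache(["full", "full"], False))    # False
```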