36 changes: 36 additions & 0 deletions vllm_ascend/attention/attention_v1.py
@@ -26,6 +26,9 @@
                                              AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import cdiv, direct_register_custom_op
from vllm.v1.core.sched.output import SchedulerOutput
@@ -37,6 +40,37 @@
from vllm_ascend.worker.npu_input_batch import InputBatch


def wait_for_kv_layer_from_connector(layer_name: str):
    """Block until the KV-transfer connector has finished loading this layer's KV cache."""
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # TODO: assert isinstance(attn_metadata, AscendMetadata)
    connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: List[torch.Tensor],
):
    """Hand this layer's KV cache to the KV-transfer connector, if one is configured."""
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # TODO: assert isinstance(attn_metadata, AscendMetadata)
    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)


class AscendAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True

@@ -510,6 +544,7 @@ def unified_ascend_attention_with_output(
    output: torch.Tensor,
    layer_name: str,
) -> None:
    # Wait for this layer's KV blocks to arrive from the connector before attention runs.
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    self = forward_context.no_compile_layers[layer_name]
@@ -522,6 +557,7 @@
                      attn_metadata,
                      output,
                      trace_flag=False)
    # Offer this layer's KV cache back to the connector (e.g. for disaggregated prefill).
    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return


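The two helpers above are no-ops unless a v1 KV-transfer group has been initialized. Below is a minimal sketch, not part of this PR, of how a deployment might enable a v1 connector so that wait_for_kv_layer_from_connector and maybe_save_kv_layer_to_connector actually fire; the connector choice, model name, and extra-config key are illustrative assumptions, and any other registered v1 connector would be wired in the same way.

# Minimal sketch (assumptions: vLLM's KVTransferConfig and the built-in
# "SharedStorageConnector"; the model name and storage path are placeholders).
from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    kv_transfer_config=KVTransferConfig(
        kv_connector="SharedStorageConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": "/tmp/kv_cache"},
    ),
)
# Once a connector is configured, has_kv_transfer_group()/is_v1_kv_transfer_group()
# return True, so each attention layer first waits for its KV blocks to load and
# then offers them back to the connector after the kernel finishes.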