From 117a10988232a8db32b3d1cba3b184fe8d5a12f6 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Sun, 27 Apr 2025 17:32:17 +0800 Subject: [PATCH 01/21] feat: impl the connector based on the llmdatadist for v1 Signed-off-by: Jade Zheng --- .../disagg_prefill_proxy_server.py | 85 ++ .../disaggregated_prefill_multi_prefill.sh | 110 +++ .../disaggregated-prefill-v1/send_request.sh | 23 + vllm_ascend/distributed/__init__.py | 5 + .../distributed/llmdatadist_connector_v1.py | 919 ++++++++++++++++++ vllm_ascend/envs.py | 3 + vllm_ascend/worker/model_runner_v1.py | 7 + 7 files changed, 1152 insertions(+) create mode 100644 examples/disaggregated-prefill-v1/disagg_prefill_proxy_server.py create mode 100644 examples/disaggregated-prefill-v1/disaggregated_prefill_multi_prefill.sh create mode 100644 examples/disaggregated-prefill-v1/send_request.sh create mode 100644 vllm_ascend/distributed/llmdatadist_connector_v1.py diff --git a/examples/disaggregated-prefill-v1/disagg_prefill_proxy_server.py b/examples/disaggregated-prefill-v1/disagg_prefill_proxy_server.py new file mode 100644 index 00000000000..41ad2b61854 --- /dev/null +++ b/examples/disaggregated-prefill-v1/disagg_prefill_proxy_server.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +import os + +import aiohttp +from quart import Quart, make_response, request + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + +app = Quart(__name__) + +PREFILL_ENDPOINT = "localhost:8100" +DECODE_ENDPOINT = "localhost:8200" + + +async def forward_request(url, data, headers: dict): + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + headers.update({ + "Authorization": + f"Bearer {os.environ.get('OPENAI_API_KEY')}", + }) + + async with session.post(url=url, json=data, + headers=headers) as response: + if response.status == 200: + async for chunk_bytes in response.content.iter_chunked(1024): + yield chunk_bytes + + +@app.route("/v1/completions", methods=["POST"]) +async def handle_request(): + try: + original_request_data = await request.get_json() + print(f"{request.headers.get('X-Request-ID')=}") + + prefill_request = original_request_data.copy() + # Change max_tokens = 1 to let it only do prefill + prefill_request["max_tokens"] = 1 + + # Finish prefill + async for prefill_result in forward_request( + f"http://{PREFILL_ENDPOINT}/v1/completions", + prefill_request, + headers={ + "X-Request-ID": request.headers.get("X-Request-ID"), + }, + ): + # Print the prefill result + print(f"===== Prefill result =====") + print(prefill_result.decode("utf-8")) + print("==========================") + response = json.loads(prefill_result.decode("utf-8")) + continue + + # Get the prefill result token, and add it to the decoding request + decode_request = original_request_data.copy() + for idx, choices in enumerate(response.get("choices")): + decode_request["prompt"][idx] += choices.get("text") + + # Return the decoding result + generator = forward_request( + f"http://{DECODE_ENDPOINT}/v1/completions", + decode_request, + headers={ + "X-Request-ID": request.headers.get("X-Request-ID"), + }, + ) + response = await make_response(generator) + response.timeout = None + + return response + + except Exception as e: + import sys + import traceback + + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server") + print(e) + print("".join(traceback.format_exception(*exc_info))) + + +if __name__ == "__main__": + app.run(port=8000) diff --git 
a/examples/disaggregated-prefill-v1/disaggregated_prefill_multi_prefill.sh b/examples/disaggregated-prefill-v1/disaggregated_prefill_multi_prefill.sh new file mode 100644 index 00000000000..6b7d31ff7a9 --- /dev/null +++ b/examples/disaggregated-prefill-v1/disaggregated_prefill_multi_prefill.sh @@ -0,0 +1,110 @@ +#!/bin/bash +# This file demonstrates the example usage of disaggregated prefilling. We will +# launch 2 vllm instances (1 for prefill and 1 for decode), and then transfer +# the KV cache between them. + +set -xe + +current_dir=$(dirname "$0") + +# vLLM Environment configuration +export VLLM_USE_V1=1 + +# vLLM-Ascend Environment configuration +export GLOBAL_RANKTABLE="${current_dir}/global_ranktable.json" +# The following environment variables are required for LLMDataDist. +export PROMPT_DEVICE_ID=0,1,2,3 +export DECODE_DEVICE_ID=4,5,6,7 +export TENSOR_PARALLEL_SIZE=$(($(echo $PROMPT_DEVICE_ID | grep -o ',' | wc -l) + 1)) + +# Model Configuration +export MODEL_NAME="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + +# Generate the global rank table +if [ ! -f "${GLOBAL_RANKTABLE}" ]; then + echo "Generating global rank table..." + # TODO(jianzs): Impl a tool to generate the global rank table automatically +else + echo "Global rank table already exists." +fi + +echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧" +sleep 1 + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'cleanup' INT + +# Cleanup function +cleanup() { + echo "Caught Ctrl+C, cleaning up..." + # Cleanup commands + pgrep python | xargs kill -9 + pkill -f python + echo "Cleanup complete. Exiting." + exit 0 +} + +# install quart first -- required for disagg prefill proxy server +if python3 -c "import quart" &>/dev/null; then + echo "Quart is already installed." +else + echo "Quart is not installed. Installing..."
+ python3 -m pip install quart +fi + +# a function that waits vLLM server to start +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +ASCEND_RT_VISIBLE_DEVICES=${PROMPT_DEVICE_ID} vllm serve ${MODEL_NAME} \ + --port 8100 \ + --max-model-len 100 \ + --gpu-memory-utilization 0.9 \ + --trust-remote-code \ + --enforce-eager \ + --no-enable-prefix-caching \ + --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \ + --kv-transfer-config \ + '{ + "kv_connector": "AscendHcclConnectorV1", + "kv_buffer_device": "npu", + "kv_role": "kv_producer", + "kv_rank": 0, + "kv_parallel_size": 2, + "kv_connector_extra_config": { + "local_server_id": "server-0" + } + }' & + +ASCEND_RT_VISIBLE_DEVICES=${DECODE_DEVICE_ID} vllm serve ${MODEL_NAME} \ + --port 8200 \ + --max-model-len 100 \ + --gpu-memory-utilization 0.9 \ + --trust-remote-code \ + --enforce-eager \ + --no-enable-prefix-caching \ + --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \ + --kv-transfer-config \ + '{ + "kv_connector": "AscendHcclConnectorV1", + "kv_buffer_device": "npu", + "kv_role": "kv_consumer", + "kv_rank": 1, + "kv_parallel_size": 2, + "kv_connector_extra_config": { + "local_server_id": "server-1" + } + }' & + +# wait until prefill and decode instances are ready +wait_for_server 8100 +wait_for_server 8200 + +echo "🚧🚧 Warning: server started 🚧🚧" + +python3 disagg_prefill_proxy_server.py diff --git a/examples/disaggregated-prefill-v1/send_request.sh b/examples/disaggregated-prefill-v1/send_request.sh new file mode 100644 index 00000000000..a5449bb4f6d --- /dev/null +++ b/examples/disaggregated-prefill-v1/send_request.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +# Make sure the model is same as the one used in the server +MODEL_NAME="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" +REQUEST_ID=request$RANDOM + +curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -H "X-Request-ID: ${REQUEST_ID}" \ + -d '{ + "ignore_eos": false, + "stream": false, + "stop": "None", + "temperature": 0.5, + "top_k": -1, + "top_p": 1, + "model": "'${MODEL_NAME}'", + "prompt": [ + "In 2020, who won the world series?", + "In 2019, Who won the world series?" 
+ ], + "max_tokens": 40 + }' diff --git a/vllm_ascend/distributed/__init__.py b/vllm_ascend/distributed/__init__.py index 88c2f2199b2..325a718f0bf 100644 --- a/vllm_ascend/distributed/__init__.py +++ b/vllm_ascend/distributed/__init__.py @@ -25,3 +25,8 @@ KVConnectorFactory.register_connector( "AscendSimpleConnector", "vllm_ascend.distributed.kv_transfer.simple_connector", "SimpleConnector") + +KVConnectorFactory.register_connector( + "AscendHcclConnectorV1", + "vllm_ascend.distributed.llmdatadist_connector_v1", + "LLMDataDistConnectorV1") diff --git a/vllm_ascend/distributed/llmdatadist_connector_v1.py b/vllm_ascend/distributed/llmdatadist_connector_v1.py new file mode 100644 index 00000000000..9f8494b758b --- /dev/null +++ b/vllm_ascend/distributed/llmdatadist_connector_v1.py @@ -0,0 +1,919 @@ +import enum +import hashlib +import json +import struct +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, List, Tuple + +import requests +import torch +import torch_npu +import torchair +from vllm.distributed import get_tensor_model_parallel_rank, get_world_group +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.config import VllmConfig, KVTransferConfig + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.request import Request + +import llm_datadist # type: ignore +from llm_datadist import LLMException, LLMStatusCode + +import vllm_ascend.envs as envs +from vllm_ascend.attention.mla_v1 import AscendMLAMetadata + +logger = init_logger(__name__) + +TORCH_DTYPE_TO_NPU_DTYPE = { + torch.half: llm_datadist.DataType.DT_FLOAT16, + torch.float16: llm_datadist.DataType.DT_FLOAT16, + torch.bfloat16: llm_datadist.DataType.DT_BF16, + torch.float: llm_datadist.DataType.DT_FLOAT, + torch.float32: llm_datadist.DataType.DT_FLOAT, + torch.int8: llm_datadist.DataType.DT_INT8, + torch.int64: llm_datadist.DataType.DT_INT64, + torch.int32: llm_datadist.DataType.DT_INT32, +} + +GLOBAL_RANKTABLE = envs.GLOBAL_RANKTABLE + + +class ServerRole(enum.Enum): + Router = "router" + Prefill = "prefill" + Decode = "decode" + + +@dataclass +class DeviceInfo: + device_id: int + device_ip: str + dp_rank: int + tp_rank: int + cluster_id: int + + +@dataclass +class ServerInfo: + server_id: str + server_ip: str + role: ServerRole + devices: List[DeviceInfo] + + def get_device(self, tp_rank: int, dp_rank: int) -> DeviceInfo: + for device in self.devices: + if device.tp_rank == tp_rank and device.dp_rank == dp_rank: + return device + return None + + +def get_servers_from_ranktable(ranktable_path: str, prefill_tp: int, + decode_tp: int) -> List[ServerInfo]: + cluster_index = 0 + + def parse_server_group(group, role: ServerRole, + tp_size: int) -> List[ServerInfo]: + nonlocal cluster_index + + server_infos: List[ServerInfo] = [] + for server in group.get("server_list", []): + server_ip = server.get("server_ip") + server_id = server.get("server_id") + + device_infos: List[DeviceInfo] = [] + for device in server.get("device", []): + device_id = int(device.get("device_id").strip()) + device_ip = device.get("device_ip") + device_infos.append( + DeviceInfo( + device_id=int(device_id), + device_ip=device_ip, + dp_rank=-1, + tp_rank=-1, + cluster_id=-1, + )) + + # Assign dp, tp rank 
and unique cluster_id to all devices in this + # server + device_infos = sorted(device_infos, key=lambda x: x.device_id) + for i, device_info in enumerate(device_infos): + device_info.dp_rank = i // tp_size + device_info.tp_rank = i % tp_size + device_info.cluster_id = cluster_index + cluster_index += 1 + + server_infos.append( + ServerInfo(server_id=server_id, + server_ip=server_ip, + role=role, + devices=device_infos)) + return server_infos + + with open(ranktable_path, "r") as file: + rank_table = json.load(file) + + for group in rank_table.get("server_group_list", []): + group_id = group.get("group_id", None) + if group_id == "0": # router + router_servers = parse_server_group(group, + ServerRole.Router, + tp_size=-1) + assert len( + router_servers + ) == 1, f"Must have only one server in group 0, but got {len(router_servers)}" + router = router_servers[0] + elif group_id == "1": # prefill + prefill_servers = parse_server_group(group, ServerRole.Prefill, + prefill_tp) + elif group_id == "2": # decode + decode_servers = parse_server_group(group, ServerRole.Decode, + decode_tp) + else: + raise ValueError( + f"Unknown group_id {group_id} in server_group_list") + return [router] + prefill_servers + decode_servers + + +class ClusterInfo: + + def __init__(self, vllm_config: "VllmConfig") -> None: + # If tensor parallel (tp) and data parallel (dp) sizes are not found in + # the extra config, use the parallel configuration of the current + # instance as default. This is useful when the prefill and decode nodes + # share the same parallel configuration. + self._tp_size = vllm_config.parallel_config.tensor_parallel_size + self._dp_size = vllm_config.parallel_config.data_parallel_size + + kv_transfer_config: "KVTransferConfig" = vllm_config.kv_transfer_config + self._prefill_parallel_config: dict[ + str, + Any] = kv_transfer_config.get_from_extra_config("prefill", {}) + self._decode_parallel_config: dict[ + str, Any] = kv_transfer_config.get_from_extra_config("decode", {}) + + self._servers: List[ServerInfo] = get_servers_from_ranktable( + GLOBAL_RANKTABLE, self.prefill_tp, self.decode_tp) + + def get_device(self, server_id: str, dp_rank: int, + tp_rank: int) -> DeviceInfo: + for server in self._servers: + if server.server_id != server_id: + continue + return server.get_device(tp_rank, dp_rank) + return None + + def get_cluster_id(self, server_id: str, dp_rank: int, + tp_rank: int) -> int: + device_info = self.get_device(server_id, dp_rank, tp_rank) + if device_info is None: + raise ValueError( + f"Could not find device({server_id},{dp_rank},{tp_rank}) in cluster info." 
+ ) + return device_info.cluster_id + + def get_servers_by_role(self, role: ServerRole) -> List[ServerInfo]: + return [server for server in self._servers if server.role == role] + + @property + def router_endpoint(self): + for server in self._servers: + if server.role == ServerRole.Router: + return f"http://{server.server_ip}:9000" + raise ValueError("Router endpoint not found") + + @property + def prefill_dp(self): + candidate_keys = ["data_parallel_size", "dp_size", "dp"] + return int( + self._get_first_matching_value(self._prefill_parallel_config, + candidate_keys, self._dp_size)) + + @property + def prefill_tp(self): + candidate_keys = ["tensor_parallel_size", "tp_size", "tp"] + return int( + self._get_first_matching_value(self._prefill_parallel_config, + candidate_keys, self._tp_size)) + + @property + def decode_dp(self): + candidate_keys = ["data_parallel_size", "dp_size", "dp"] + return int( + self._get_first_matching_value(self._decode_parallel_config, + candidate_keys, self._dp_size)) + + @property + def decode_tp(self): + candidate_keys = ["tensor_parallel_size", "tp_size", "tp"] + return int( + self._get_first_matching_value(self._decode_parallel_config, + candidate_keys, self._tp_size)) + + def _get_first_matching_value(self, config_dict: dict, + candidate_keys: List[str], + default: Any) -> Any: + for key in candidate_keys: + if key in config_dict: + return config_dict[key] + return default + + +_CLUSTER_INFO: "ClusterInfo" = None + + +def init_cluster_info(vllm_config: "VllmConfig") -> None: + global _CLUSTER_INFO + if _CLUSTER_INFO is not None: + raise ValueError("ClusterInfo is already initialized.") + _CLUSTER_INFO = ClusterInfo(vllm_config) + + +def get_cluster_info() -> "ClusterInfo": + global _CLUSTER_INFO + if _CLUSTER_INFO is None: + raise ValueError("ClusterInfo is not initialized.") + return _CLUSTER_INFO + + +def report_prefill_info(meta_server_url, prefill_info): + response = requests.post(f"{meta_server_url}/put", json=prefill_info) + if response.status_code != 200: + logger.error( + f"put_prefill_info failed status_code: {response.status_code}, response: {response.text}" + ) + + +def fetch_prefill_info(meta_server_url, request_ids): + response = requests.get(f"{meta_server_url}/get", json=request_ids) + if response.status_code != 200: + logger.error( + f"get_prefill_info failed status_code: {response.status_code}, response: {response.text}" + ) + return None + return response.json() + + +class KVTransferEngine: + + def __init__(self, role: llm_datadist.LLMRole, local_rank: int, + dp_rank: int, tp_rank: int, local_server_id: str) -> None: + self.role = role + self.local_rank = local_rank + self.tp_rank = tp_rank + self.cluster_info = get_cluster_info() + + local_device_info = self.cluster_info.get_device( + local_server_id, dp_rank, tp_rank) + assert local_device_info is not None, \ + f"Could not find local device from cluster info." + + self.cluster_id = local_device_info.cluster_id + self.local_device_ip = local_device_info.device_ip + self.datadist_engine = llm_datadist.LLMDataDist( + self.role, self.cluster_id) + + def prepare_data_dist(self): + options = { + "llm.SyncKvCacheWaitTime": envs.LLMDATADIST_SYNC_CACHE_WAIT_TIME, + } + if self.role == llm_datadist.LLMRole.PROMPT: + # TODO: This represents the maximum size of the mbuf for the llm + # datadist. We need to find an appropriate value to minimize memory + # waste. 
+ # options["ge.flowGraphMemMaxSize"] = "1024" # MB + options["ge.exec.deviceId"] = str(self.local_rank) + options["llm.listenIpInfo"] = f"{self.local_device_ip}:26000" + else: + options["ge.exec.deviceId"] = str(self.local_rank) + self.datadist_engine.init(options) + self.kv_transfer = self.datadist_engine.kv_cache_manager + + def make_cluster(self, prefill_ip, cluster_id=-1): + cluster = llm_datadist.LLMClusterInfo() + cluster.remote_cluster_id = cluster_id + cluster.append_local_ip_info(self.local_device_ip, 0) + cluster.append_remote_ip_info(prefill_ip, 26000) + logger.info(f"link decode ip {self.local_device_ip} -> {prefill_ip}") + return cluster + + def make_clusters(self): + clusters = [] + # Find all devices from prefill servers this rank need to connect + for server in self.cluster_info.get_servers_by_role( + ServerRole.Prefill): + for device in server.devices: + target_tp_rank = self.tp_rank % min( + self.cluster_info.prefill_tp, self.cluster_info.decode_tp) + if target_tp_rank == device.tp_rank: + cluster = self.make_cluster(device.device_ip, + device.cluster_id) + clusters.append(cluster) + return clusters + + +@dataclass +class ReqMeta: + # Request ID, unique for each request + request_id: str + # Request tokens + token_ids: torch.Tensor + # Slot mappings, should have the same length as token_ids + slot_mapping: torch.Tensor + # Is store or load + is_store: bool + + @staticmethod + def make_meta(request_id: str, token_ids: list[int], block_ids: list[int], + block_size: int, is_store: bool) -> "ReqMeta": + token_ids_tensor = torch.tensor(token_ids) + valid_num_tokens = len(token_ids) + block_ids_tensor = torch.tensor(block_ids) + num_blocks = block_ids_tensor.shape[0] + block_offsets = torch.arange(0, block_size) + slot_mapping = block_offsets.reshape( + (1, block_size)) + block_ids_tensor.reshape( + (num_blocks, 1)) * block_size + slot_mapping = slot_mapping.flatten()[:valid_num_tokens] + return ReqMeta( + request_id=request_id, + token_ids=token_ids_tensor, + slot_mapping=slot_mapping, + is_store=is_store, + ) + + +@dataclass +class LLMDataDistConnectorV1Metadata(KVConnectorMetadata): + requests: list[ReqMeta] + + def __init__(self): + self.requests = [] + + def add_request(self, request_id: str, token_ids: list[int], + block_ids: list[int], block_size: int, + is_store: bool) -> None: + self.requests.append( + ReqMeta.make_meta(request_id, token_ids, block_ids, block_size, + is_store)) + + +class LLMDataDistConnectorV1(KVConnectorBase_V1): + + def __init__(self, vllm_config: "VllmConfig", + role: KVConnectorRole) -> None: + super().__init__(vllm_config=vllm_config, role=role) + + # Used by both scheduler and worker process + kv_transfer_config: "KVTransferConfig" = self._vllm_config.kv_transfer_config + self._block_size = vllm_config.cache_config.block_size + if kv_transfer_config.is_kv_producer: + self.kv_role = llm_datadist.LLMRole.PROMPT + elif kv_transfer_config.is_kv_consumer: + self.kv_role = llm_datadist.LLMRole.DECODER + else: + raise ValueError( + f"The value of kv_role must be either `kv_producer` or `kv_consumer`, but received {kv_transfer_config.kv_role}." + ) + + # Used by scheduler process + self._requests_need_load: dict[str, Request] = {} + + if role == KVConnectorRole.SCHEDULER: + # In the scheduler process, the distributed environment is not + # initialized. As a result, functions like `get_world_group` cannot + # be used. Additionally, the scheduler does not require initializing + # the KVTransferEngine. Therefore, simply return. 
+ return + + # Used by worker process + init_cluster_info(self._vllm_config) + self.cluster_info = get_cluster_info() + + self.local_server_id = kv_transfer_config.get_from_extra_config( + "local_server_id", None) + assert ( + self.local_server_id is not None + ), f"Cannot find `local_server_id` from `kv_transfer_config.kv_connector_extra_config`." + + self.dp_rank = self._vllm_config.parallel_config.data_parallel_rank + self.tp_size = self._vllm_config.parallel_config.tensor_parallel_size + self.tp_rank = get_tensor_model_parallel_rank() + self.num_layers = self._vllm_config.model_config.get_num_layers( + self._vllm_config.parallel_config) + + local_rank = get_world_group().local_rank + self.llm_datadist_engine = KVTransferEngine(self.kv_role, local_rank, + self.dp_rank, self.tp_rank, + self.local_server_id) + self.llm_datadist_engine.prepare_data_dist() + if self.kv_role == llm_datadist.LLMRole.DECODER: + while True: + try: + # Each decoding rank should correspond to each prefilling rank. + clusters = self.llm_datadist_engine.make_clusters() + _, ret = self.llm_datadist_engine.datadist_engine.link_clusters( + clusters, 20000) + logger.info(f"{local_rank} link, ret={ret}") + break + except LLMException as e: + logger.error( + f"Failed to link clusters, local_rank {local_rank}, error: {e}" + ) + time.sleep(1) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """ + Start loading the KV cache from the connector to vLLM's paged + KV buffer. This is called from the forward context before the + forward pass to enable async loading during model execution. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + + """ + # Note: we recv kv cache by request now, so we do not need by + # layer operations, all recv is done in start_load_kv + + if self.kv_role == llm_datadist.LLMRole.PROMPT: + # In the prefilling node, do not need to load KV cache. + return + + # Get the metadata + metadata = self._get_connector_metadata() + assert isinstance(metadata, LLMDataDistConnectorV1Metadata) + assert metadata is not None, "The connector metadata should not be None." + if len(metadata.requests) == 0: + # No requests to load + return + + attn_metadata = forward_context.attn_metadata + assert attn_metadata is not None, "The attn_metadata should not be None." + + request_ids = [ + self._get_unique_req_id(req.request_id) + for req in metadata.requests if not req.is_store + ] + prefill_infos = fetch_prefill_info(self.cluster_info.router_endpoint, + request_ids) + # If prefill_infos is None, it indicates that get_prefill_info failed. + # Therefore, we need to recalculate the kv cache during the decoding + # phase. If there is a performance issue, we should consider whether + # this is the cause. 
+ if prefill_infos is None: + logger.error( + f"[rank%d][D]: Failed to get prefill info, redo model forwarding.", + torch.distributed.get_rank()) + return None + + kv_cache_layers = [] + for _, attn_layer in forward_context.no_compile_layers.items(): + kv_cache_layer = attn_layer.kv_cache[ + forward_context.virtual_engine] + kv_cache_layers.append(kv_cache_layer) + + is_mla = isinstance(attn_metadata, AscendMLAMetadata) + kv_cache_layer_shape = list(kv_cache_layers[0].shape) + num_heads = int(kv_cache_layer_shape[-2]) + head_dim = int(kv_cache_layer_shape[-1]) + # Load the KV for each request each layer + for request in metadata.requests: + if request.is_store: + continue + # NOTE: slen is the len of kv cache need to load for this request + # in decode, request_len = prefill_prompt_len + 1 + slen = request.token_ids.shape[0] - 1 + cur_slot_mapping = request.slot_mapping[:slen] + + # For the datadist tensor, the first dimension is 1, the reason can + # be found in wait_for_save function + if is_mla: + # [1, slen, num_heads, head_dim] + kv_cache_shape: Tuple[int, + ...] = (1, slen, num_heads, head_dim) + else: + # [1, 2, slen, num_heads, head_dim] + kv_cache_shape = (1, 2, slen, num_heads, head_dim) + + uniq_req_id = self._get_unique_req_id(request.request_id) + dp_rank = prefill_infos[uniq_req_id]["dp_rank"] + server_id = prefill_infos[uniq_req_id]["server_id"] + + # pull kv cache from prefill node by request + kv_hidden_dtype = kv_cache_layers[0].dtype + kv_buffer, pulled_kv_caches = self._create_cache_tensors( + self.num_layers, kv_cache_shape, kv_hidden_dtype) + + target_tp_rank = self.tp_rank % min( + self.cluster_info.prefill_tp, + self.cluster_info.decode_tp, + ) + remote_cluster_id = self.cluster_info.get_cluster_id( + server_id, dp_rank, target_tp_rank) + + # Each request uses the same llm_datadist request_id, which needs to + # be converted into an integer value. + datadist_request_id = string_to_int64_hash(request.request_id) + kv_cache_key = llm_datadist.CacheKey(remote_cluster_id, + datadist_request_id, 1) + self.llm_datadist_engine.kv_transfer.pull_cache( + kv_cache_key, kv_buffer, 0) + + # Check for any transmission failures; we need to redo the + # forwarding to compute the missing states. + if pulled_kv_caches is None: + logger.error( + "[rank%d][D]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", + torch.distributed.get_rank(), + ) + # TODO: break or continue? + break + + for layer_id, kv_cache_layer in enumerate(kv_cache_layers): + pulled_kv_cache = pulled_kv_caches[layer_id] + self._inject_kv_into_layer(kv_cache_layer, pulled_kv_cache, + cur_slot_mapping, is_mla) + + # Release the reference count + self.llm_datadist_engine.kv_transfer.deallocate_cache(kv_buffer) + + def wait_for_layer_load(self, layer_name: str) -> None: + """ + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + pass + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """ + Start saving a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. 
+ kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + pass + + def wait_for_save(self): + """ + Block until all the save operations is done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. + """ + # Note: we send kv cache by request now, so we do not need by + # layer operations, all send is done in wait_for_save + + if self.kv_role == llm_datadist.LLMRole.DECODER: + # In the prompt role, we do not need to load KV cache. + return + + forward_context = get_forward_context() + metadata: KVConnectorMetadata = self._get_connector_metadata() + assert isinstance(metadata, LLMDataDistConnectorV1Metadata) + + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + logger.warning( + "In connector.start_load_kv, but the attn_metadata is None") + return + + is_mla = isinstance(attn_metadata, AscendMLAMetadata) + indices = torch.tensor([0], dtype=torch.int64, device="npu") + + prefill_info_input = {} + # kv cache should be transfered by request + for _, request in enumerate(metadata.requests): + if not request.is_store: + continue + + slen = request.token_ids.shape[0] + req_slot_mapping = request.slot_mapping[:slen] + + uniq_req_id = self._get_unique_req_id(request.request_id) + prefill_info_input[uniq_req_id] = { + "dp_rank": self.dp_rank, + "server_id": self.local_server_id, + } + + kv_caches: List[torch.Tensor] = [] + for _, attn_layer in forward_context.no_compile_layers.items(): + kv_cache_layer = attn_layer.kv_cache[ + forward_context.virtual_engine] + kv_cache = self._extract_kv_from_layer(kv_cache_layer, + req_slot_mapping, + is_mla) + kv_caches.append(kv_cache.detach()) + + # Initialize LLMDatadist data structure. Each request uses the same + # llm_datadist request_id, which needs to be converted to an integer + # value. + datadist_request_id = string_to_int64_hash(request.request_id) + kv_cache_keys = [ + llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, + datadist_request_id, 1) + ] + + # If MLA is used, the kv_cache_shape should be (1, slen, num_heads, + # head_dim). Otherwise, it should be (1, 2, slen, num_heads, + # head_dim). The first dimension must be 1, because the following + # `scatter_update_` operation will fail otherwise. The exact reason + # for this limitation is currently unknown. + kv_cache_shape = (1, ) + tuple(kv_caches[0].shape) + kv_hidden_dtype = kv_caches[0].dtype + kv_buffer, pushed_kv_caches = self._create_cache_tensors( + self.num_layers, kv_cache_shape, kv_hidden_dtype, + kv_cache_keys) + for layer_idx, kv_cache in enumerate(kv_caches): + datadist_kv_cache = pushed_kv_caches[layer_idx] + kv_cache = kv_cache.unsqueeze(0) + torch_npu.scatter_update_(datadist_kv_cache, + indices, + kv_cache, + axis=-2) + + # Release reference count + self.llm_datadist_engine.kv_transfer.deallocate_cache(kv_buffer) + + # Report prefill info to meta server + report_prefill_info(self.cluster_info.router_endpoint, + prefill_info_input) + logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank()) + + def _inject_kv_into_layer( + self, + dst_kv_cache_layer: torch.Tensor, + pulled_kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + is_mla: bool, + ) -> None: + """Inject the KV cache into the layer. 
+ + Args: + dst_kv_cache_layer (torch.Tensor): the destination KV cache + layer. In shape [2, num_blocks, block_size, num_heads, head_dim] + if not using MLA, [num_blocks, block_size, num_heads, head_dim] + otherwise. + src_kv_cache (torch.Tensor): the source KV cache. In shape + [1, 2, num_tokens, num_heads, head_dim] if not using MLA, [1, + num_tokens, num_heads, head_dim] otherwise. + slot_mapping (torch.Tensor): the slot mapping. In shape + [num_tokens]. + """ + # NOTE: The performance of this function is suboptimal. Using + # `torch_npu._npu_reshape_and_cache` or + # `torch_npu._npu_reshape_and_cache_siso` could improve performance + # significantly. However, attempts to use these methods have failed, and + # the root cause remains unclear. The only available information is an + # error log from the ATB log file, which states: + # "ReshapeAndCacheOperation_1 invalid param, setup check fail, error + # code: 13." + + # The pulled KV cache resides in the mbuf memory space and cannot be + # directly copied to the kv_cache_layer. Therefore, it must first be + # copied to a standard torch tensor using `scatter_update_`. + kv_cache = torch.empty_like(pulled_kv_cache) + indices = torch.tensor([0], dtype=torch.int64, device="npu") + torch_npu.scatter_update_(kv_cache, indices, pulled_kv_cache, axis=-2) + # The `wait_for_save` function explains why the first dimension is + # necessary. + kv_cache = kv_cache.squeeze(0) + + dst_kv_cache_layer_shape = dst_kv_cache_layer.shape + if is_mla: + block_size = dst_kv_cache_layer_shape[1] + num_heads = dst_kv_cache_layer_shape[2] + head_dim = dst_kv_cache_layer_shape[3] + idx_for_copy = slot_mapping // block_size * block_size + slot_mapping % block_size + dst_kv_cache_layer = dst_kv_cache_layer.view( + -1, num_heads, head_dim) + dst_kv_cache_layer[idx_for_copy, ...] = kv_cache + else: + block_size = dst_kv_cache_layer_shape[2] + num_heads = dst_kv_cache_layer_shape[3] + head_dim = dst_kv_cache_layer_shape[4] + idx_for_copy = slot_mapping // block_size * block_size + slot_mapping % block_size + dst_kv_cache_layer = dst_kv_cache_layer.view( + 2, -1, num_heads, head_dim) + dst_kv_cache_layer[:, idx_for_copy, ...] = kv_cache + + def _extract_kv_from_layer( + self, + kv_cache_layer: torch.Tensor, + slot_mapping: torch.Tensor, + is_mla: bool, + ) -> torch.Tensor: + """Extract the KV cache from the layer. + + Assume the shape of the layer is [2, num_blocks, block_size, num_heads, + head_dim] if MLA is not used, and [num_blocks, block_size, num_heads, + head_dim] otherwise. + """ + if is_mla: + num_heads, head_dim = kv_cache_layer.shape[ + 2], kv_cache_layer.shape[3] + return kv_cache_layer.view(-1, num_heads, head_dim)[slot_mapping, + ...] + + num_heads, head_dim = kv_cache_layer.shape[2], kv_cache_layer.shape[3] + return kv_cache_layer.view(2, -1, num_heads, head_dim)[:, slot_mapping, + ...] + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> int: + """ + Get number of new tokens that can be loaded from the external KV cache + beyond the num_computed_tokens. + + Args: + request (Request): the request object. num_computed_tokens (int): + the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the external KV cache + beyond what is already computed. + """ + # NOTE: in current v1 scheduler, the num_computed_tokens is aligned with + # the block granularity. 
And it expects the returned blocks and + # num_computed_tokens to also be aligned with the block granularity. + + # NOTE: only request in waiting queue will come here. we use datadist + # pull cache to do transfer, so we don't align to block_size in prefill, + # we won't have extra new matched tokens; in decode, new request kv + # cache will be transferred from prefill, so num_computed_tokens = 0, and + # extra new matched tokens should be len(request.prompt_token_ids) - 1 + if self.kv_role == llm_datadist.LLMRole.PROMPT: + return 0 + return len(request.prompt_token_ids) - 1 + + def update_state_after_alloc(self, request: "Request", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + """ + if num_external_tokens > 0: + self._requests_need_load[request.request_id] = request + + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. Also, + calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + meta = LLMDataDistConnectorV1Metadata() + + total_need_load = 0 + for new_req in scheduler_output.scheduled_new_reqs: + if new_req.req_id in self._requests_need_load: + meta.add_request( + request_id=new_req.req_id, + token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids, + block_size=self._block_size, + is_store=False, + ) + total_need_load += 1 + else: + # NOTE: here, we set the store and load being exclusive, + # but a single request can have both store and load. + # NOTE(rob): for this debug implementation, we only cache + # the original prompt tokens. + meta.add_request( + request_id=new_req.req_id, + token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids, + block_size=self._block_size, + is_store=True, + ) + + for cached_req in scheduler_output.scheduled_cached_reqs: + # NOTE(rob): here we rely on the resumed requests being + # the first N requests in the list scheduled_cached_reqs. + if not cached_req.resumed_from_preemption: + break + if cached_req.req_id in self._requests_need_load: + # NOTE(rob): cached_req_data does not have the full + # list of token ids (only new tokens). So we look it + # up in the actual request object. + request = self._requests_need_load[cached_req.req_id] + total_tokens = len( + cached_req.new_token_ids) + cached_req.num_computed_tokens + token_ids = request.all_token_ids[:total_tokens] + + meta.add_request( + request_id=cached_req.req_id, + token_ids=token_ids, + block_ids=cached_req.new_block_ids, + block_size=self._block_size, + is_store=False, + ) + total_need_load += 1 + + assert total_need_load == len(self._requests_need_load) + self._requests_need_load.clear() + return meta + + def _create_cache_tensors(self, + num_layer: int, + shape: Tuple[int, ...], + dtype: torch.dtype, + cache_keys=[]): + seq_len_dim_index = -2 + len(shape) + cache_desc = llm_datadist.CacheDesc( + num_layer, + shape, + TORCH_DTYPE_TO_NPU_DTYPE[dtype], + seq_len_dim_index=seq_len_dim_index) + # TODO(jianzs): At present, there is no method to determine the + # available space in the mbuf memory. Therefore, we can only attempt to + # handle allocation failures; if the failure is due to insufficient + # space, we pause briefly before retrying until the allocation succeeds.
+ while True: + try: + cache_buf = self.llm_datadist_engine.kv_transfer.allocate_cache( + cache_desc, cache_keys) + break + except LLMException as e: + if e.status_code == LLMStatusCode.LLM_DEVICE_OUT_OF_MEMORY: + logger.warning( + f"allocate_cache failed due to insufficient space in the mbuf memory." + ) + time.sleep(0.03) # wait for cache buf to be ready + else: + raise e + cache_buf_addrs = cache_buf.per_device_tensor_addrs[0] + cache_tensors = torchair.llm_datadist.create_npu_tensors( + cache_desc.shape, dtype, cache_buf_addrs) + return cache_buf, cache_tensors + + def _get_unique_req_id(self, request_id: str) -> str: + return f"{request_id}-{self.tp_rank}" + + +# ============================== +# Helper functions +# ============================== + + +def parse_config_string(config_string: str) -> dict: + config_dict = {} + parts = config_string.split(";") + + for part in parts: + if ":" in part: + key, values = part.split(":") + value_parts = values.split(",") + for value_part in value_parts: + if "=" in value_part: + sub_key, sub_value = value_part.split("=") + config_dict[f"{key}_{sub_key}"] = int(sub_value) + else: + sub_key, sub_value = value_part.split("p") + config_dict[f"{key}_{sub_key}p"] = int(sub_value) + + return config_dict + + +def string_to_int64_hash(input_str): + """ + Hash the string using SHA-256 and convert it into an int64 integer. + """ + hashed_bytes = hashlib.sha256(input_str.encode("utf-8")).digest() + trunked_bytes = hashed_bytes[:8] + uint64_value = struct.unpack(" Union[ModelRunnerOutput, torch.Tensor]: + # Update KVConnector with the KVConnector metadata forward(). + if has_kv_transfer_group(): + get_kv_transfer_group().bind_connector_metadata( + scheduler_output.kv_connector_metadata) + with ProfileExecuteDuration().capture_async( "prepare input and forward"): self._update_states(scheduler_output) From 6dfbaa6bbadc666eb18d5f9faac38b29b2354778 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Sun, 27 Apr 2025 21:15:45 +0800 Subject: [PATCH 02/21] feat: resolve npu_reshape_and_cache error Ensure correct input for npu_reshape_and_cache function The 'slot_indices' parameter of npu_reshape_and_cache must be: - A torch.int32 tensor - Located on the NPU device Signed-off-by: Jade Zheng --- .../distributed/llmdatadist_connector_v1.py | 54 +++++++------------ 1 file changed, 18 insertions(+), 36 deletions(-) diff --git a/vllm_ascend/distributed/llmdatadist_connector_v1.py b/vllm_ascend/distributed/llmdatadist_connector_v1.py index 9f8494b758b..599388caf10 100644 --- a/vllm_ascend/distributed/llmdatadist_connector_v1.py +++ b/vllm_ascend/distributed/llmdatadist_connector_v1.py @@ -334,9 +334,9 @@ def make_meta(request_id: str, token_ids: list[int], block_ids: list[int], block_size: int, is_store: bool) -> "ReqMeta": token_ids_tensor = torch.tensor(token_ids) valid_num_tokens = len(token_ids) - block_ids_tensor = torch.tensor(block_ids) + block_ids_tensor = torch.tensor(block_ids, dtype=torch.int32) num_blocks = block_ids_tensor.shape[0] - block_offsets = torch.arange(0, block_size) + block_offsets = torch.arange(0, block_size, dtype=torch.int32) slot_mapping = block_offsets.reshape( (1, block_size)) + block_ids_tensor.reshape( (num_blocks, 1)) * block_size @@ -495,7 +495,7 @@ def start_load_kv(self, forward_context: "ForwardContext", # NOTE: slen is the len of kv cache need to load for this request # in decode, request_len = prefill_prompt_len + 1 slen = request.token_ids.shape[0] - 1 - cur_slot_mapping = request.slot_mapping[:slen] + req_slot_mapping = 
request.slot_mapping[:slen].to(device="npu") # For the datadist tensor, the first dimension is 1, the reason can # be found in wait_for_save function @@ -545,7 +545,7 @@ def start_load_kv(self, forward_context: "ForwardContext", for layer_id, kv_cache_layer in enumerate(kv_cache_layers): pulled_kv_cache = pulled_kv_caches[layer_id] self._inject_kv_into_layer(kv_cache_layer, pulled_kv_cache, - cur_slot_mapping, is_mla) + req_slot_mapping, is_mla) # Release the reference count self.llm_datadist_engine.kv_transfer.deallocate_cache(kv_buffer) @@ -686,42 +686,24 @@ def _inject_kv_into_layer( slot_mapping (torch.Tensor): the slot mapping. In shape [num_tokens]. """ - # NOTE: The performance of this function is suboptimal. Using - # `torch_npu._npu_reshape_and_cache` or - # `torch_npu._npu_reshape_and_cache_siso` could improve performance - # significantly. However, attempts to use these methods have failed, and - # the root cause remains unclear. The only available information is an - # error log from the ATB log file, which states: - # "ReshapeAndCacheOperation_1 invalid param, setup check fail, error - # code: 13." - - # The pulled KV cache resides in the mbuf memory space and cannot be - # directly copied to the kv_cache_layer. Therefore, it must first be - # copied to a standard torch tensor using `scatter_update_`. - kv_cache = torch.empty_like(pulled_kv_cache) - indices = torch.tensor([0], dtype=torch.int64, device="npu") - torch_npu.scatter_update_(kv_cache, indices, pulled_kv_cache, axis=-2) # The `wait_for_save` function explains why the first dimension is # necessary. - kv_cache = kv_cache.squeeze(0) - - dst_kv_cache_layer_shape = dst_kv_cache_layer.shape + kv_cache = pulled_kv_cache.squeeze(0) if is_mla: - block_size = dst_kv_cache_layer_shape[1] - num_heads = dst_kv_cache_layer_shape[2] - head_dim = dst_kv_cache_layer_shape[3] - idx_for_copy = slot_mapping // block_size * block_size + slot_mapping % block_size - dst_kv_cache_layer = dst_kv_cache_layer.view( - -1, num_heads, head_dim) - dst_kv_cache_layer[idx_for_copy, ...] = kv_cache + torch_npu._npu_reshape_and_cache_siso( + key=kv_cache, + key_cache=dst_kv_cache_layer, + slot_indices=slot_mapping, + ) + else: - block_size = dst_kv_cache_layer_shape[2] - num_heads = dst_kv_cache_layer_shape[3] - head_dim = dst_kv_cache_layer_shape[4] - idx_for_copy = slot_mapping // block_size * block_size + slot_mapping % block_size - dst_kv_cache_layer = dst_kv_cache_layer.view( - 2, -1, num_heads, head_dim) - dst_kv_cache_layer[:, idx_for_copy, ...] 
= kv_cache + torch_npu._npu_reshape_and_cache( + key=kv_cache[0], + value=kv_cache[1], + key_cache=dst_kv_cache_layer[0], + value_cache=dst_kv_cache_layer[1], + slot_indices=slot_mapping, + ) def _extract_kv_from_layer( self, From 193ef5a84d8cb3421365185a0638ab7a3ee9dce1 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Mon, 28 Apr 2025 11:19:58 +0800 Subject: [PATCH 03/21] chore: lint code Signed-off-by: Jade Zheng --- .../disagg_prefill_proxy_server.py | 4 +-- .../distributed/llmdatadist_connector_v1.py | 25 ++++++++++--------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/examples/disaggregated-prefill-v1/disagg_prefill_proxy_server.py b/examples/disaggregated-prefill-v1/disagg_prefill_proxy_server.py index 41ad2b61854..fadf1f0387d 100644 --- a/examples/disaggregated-prefill-v1/disagg_prefill_proxy_server.py +++ b/examples/disaggregated-prefill-v1/disagg_prefill_proxy_server.py @@ -4,7 +4,7 @@ import os import aiohttp -from quart import Quart, make_response, request +from quart import Quart, make_response, request # type: ignore AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) @@ -47,7 +47,7 @@ async def handle_request(): }, ): # Print the prefill result - print(f"===== Prefill result =====") + print("===== Prefill result =====") print(prefill_result.decode("utf-8")) print("==========================") response = json.loads(prefill_result.decode("utf-8")) diff --git a/vllm_ascend/distributed/llmdatadist_connector_v1.py b/vllm_ascend/distributed/llmdatadist_connector_v1.py index 599388caf10..f843cd2c998 100644 --- a/vllm_ascend/distributed/llmdatadist_connector_v1.py +++ b/vllm_ascend/distributed/llmdatadist_connector_v1.py @@ -4,12 +4,12 @@ import struct import time from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, List, Tuple +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union import requests import torch import torch_npu -import torchair +import torchair # type: ignore from vllm.distributed import get_tensor_model_parallel_rank, get_world_group from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) @@ -67,7 +67,8 @@ class ServerInfo: role: ServerRole devices: List[DeviceInfo] - def get_device(self, tp_rank: int, dp_rank: int) -> DeviceInfo: + def get_device(self, tp_rank: int, + dp_rank: int) -> Union[DeviceInfo, None]: for device in self.devices: if device.tp_rank == tp_rank and device.dp_rank == dp_rank: return device @@ -162,7 +163,7 @@ def __init__(self, vllm_config: "VllmConfig") -> None: GLOBAL_RANKTABLE, self.prefill_tp, self.decode_tp) def get_device(self, server_id: str, dp_rank: int, - tp_rank: int) -> DeviceInfo: + tp_rank: int) -> Union[DeviceInfo, None]: for server in self._servers: if server.server_id != server_id: continue @@ -225,7 +226,7 @@ def _get_first_matching_value(self, config_dict: dict, return default -_CLUSTER_INFO: "ClusterInfo" = None +_CLUSTER_INFO: Optional["ClusterInfo"] = None def init_cluster_info(vllm_config: "VllmConfig") -> None: @@ -272,7 +273,7 @@ def __init__(self, role: llm_datadist.LLMRole, local_rank: int, local_device_info = self.cluster_info.get_device( local_server_id, dp_rank, tp_rank) assert local_device_info is not None, \ - f"Could not find local device from cluster info." + "Could not find local device from cluster info." 
self.cluster_id = local_device_info.cluster_id self.local_device_ip = local_device_info.device_ip @@ -379,8 +380,8 @@ def __init__(self, vllm_config: "VllmConfig", self.kv_role = llm_datadist.LLMRole.DECODER else: raise ValueError( - f"The value of kv_role must be either `kv_producer` or `kv_consumer`, but received {kv_transfer_config.kv_role}." - ) + "The value of kv_role must be either `kv_producer` or" + f" `kv_consumer`, but received {kv_transfer_config.kv_role}.") # Used by scheduler process self._requests_need_load: dict[str, Request] = {} @@ -400,7 +401,7 @@ def __init__(self, vllm_config: "VllmConfig", "local_server_id", None) assert ( self.local_server_id is not None - ), f"Cannot find `local_server_id` from `kv_transfer_config.kv_connector_extra_config`." + ), "Cannot find `local_server_id` from `kv_transfer_config.kv_connector_extra_config`." self.dp_rank = self._vllm_config.parallel_config.data_parallel_rank self.tp_size = self._vllm_config.parallel_config.tensor_parallel_size @@ -474,7 +475,7 @@ def start_load_kv(self, forward_context: "ForwardContext", # this is the cause. if prefill_infos is None: logger.error( - f"[rank%d][D]: Failed to get prefill info, redo model forwarding.", + "[rank%d][D]: Failed to get prefill info, redo model forwarding.", torch.distributed.get_rank()) return None @@ -853,8 +854,8 @@ def _create_cache_tensors(self, except LLMException as e: if e.status_code == LLMStatusCode.LLM_DEVICE_OUT_OF_MEMORY: logger.warning( - f"allocate_cache failed due to insufficient space in the mbuf memory." - ) + "allocate_cache failed due to insufficient space in the" + " mbuf memory.") time.sleep(0.03) # wait for cache buf to be ready else: raise e From c65fde9ef7013ed053fa56d8bbe8f47ab9160ce9 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Mon, 28 Apr 2025 12:29:31 +0800 Subject: [PATCH 04/21] feat: add offline inference example Signed-off-by: Jade Zheng --- .../offling_inference.py | 201 ++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 examples/disaggregated-prefill-v1/offling_inference.py diff --git a/examples/disaggregated-prefill-v1/offling_inference.py b/examples/disaggregated-prefill-v1/offling_inference.py new file mode 100644 index 00000000000..909d81030b4 --- /dev/null +++ b/examples/disaggregated-prefill-v1/offling_inference.py @@ -0,0 +1,201 @@ +""" +This file demonstrates the example usage of disaggregated prefilling. We will +launch 2 vllm instances (NPU 0,1,2,3 for prefill and NPU 4,5,6,7 for decode), +and then transfer the KV cache between them.
+""" + +import multiprocessing as mp +import os +from multiprocessing import Event, Process, Queue +from typing import List, Literal + + +def get_kv_transfer_config(role: Literal["kv_producer", "kv_consumer"], + local_server_id: str): + kv_rank = 0 if role == "kv_producer" else 1 + return f"""{{ + "kv_connector": "AscendHcclConnectorV1", + "kv_buffer_device": "npu", + "kv_role": "{role}", + "kv_rank": {kv_rank}, + "kv_parallel_size": 2, + "kv_connector_extra_config": {{ + "local_server_id": "{local_server_id}" + }} + }}""" + + +def clean_up(): + import gc + + import torch + from vllm.distributed.parallel_state import ( + destroy_distributed_environment, destroy_model_parallel) + + destroy_model_parallel() + destroy_distributed_environment() + gc.collect() + torch.npu.empty_cache() + + +def run_prefill( + prefill_done, + process_close, + prompt_q: Queue, + prompts: List[str], + model: str, + local_server_id: str, + visible_devices: str, +): + os.environ["VLLM_USE_V1"] = "1" + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = visible_devices + tensor_parallel_size = len(visible_devices.split(",")) + + from vllm import LLM, SamplingParams + from vllm.config import KVTransferConfig + + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) + + ktc = KVTransferConfig.from_cli( + get_kv_transfer_config( + role="kv_producer", + local_server_id=local_server_id, + )) + + llm = LLM( + model=model, + trust_remote_code=True, + enforce_eager=True, + enable_prefix_caching=False, + kv_transfer_config=ktc, + tensor_parallel_size=tensor_parallel_size, + gpu_memory_utilization=0.9, + max_model_len=40, + ) + + result = llm.generate(prompts, sampling_params) + for output in result: + prompt = output.prompt + generated_text = output.outputs[0].text + print( + f"[Prefill] Prompt: {prompt!r}, Generated text: {generated_text!r}" + ) + prompt_q.put(prompt + generated_text) + prompt_q.close() + + print("[Prefill] DONE.") + prefill_done.set() + + # To keep the prefill node running in case the decode node is not done; + # otherwise, the script might exit prematurely, causing incomplete decoding. + process_close.wait() + + del llm + clean_up() + + +def run_decode( + prefill_done, + prompt_q: Queue, + num_prompts: int, + model: str, + local_server_id: str, + visible_devices: str, +): + os.environ["VLLM_USE_V1"] = "1" + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = visible_devices + tensor_parallel_size = len(visible_devices.split(",")) + + from vllm import LLM, SamplingParams + from vllm.config import KVTransferConfig + + sampling_params = SamplingParams(temperature=0, top_p=0.95) + + ktc = KVTransferConfig.from_cli( + get_kv_transfer_config( + role="kv_consumer", + local_server_id=local_server_id, + )) + + llm = LLM( + model=model, + trust_remote_code=True, + enforce_eager=True, + enable_prefix_caching=False, + kv_transfer_config=ktc, + tensor_parallel_size=tensor_parallel_size, + gpu_memory_utilization=0.9, + max_model_len=40, + ) + + # Wait for the producer to start the comsumer + print("[Decode] Waiting for prefill node to finish...") + prefill_done.wait() + + # Get the prompts from the queue + prompts = [] + for _ in range(num_prompts): + prompts.append(prompt_q.get()) + + # At this point when the prefill_done is set, the kv-cache should have been + # transferred to this decode node, so we can start decoding. 
+ outputs = llm.generate(prompts, sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print( + f"[Decode] Prompt: {prompt!r}, Generated text: {generated_text!r}") + print("[Decode] DONE.") + + # Must delete the llm instance, otherwise the process will not exit + del llm + clean_up() + + +if __name__ == "__main__": + mp.get_context("spawn") + + model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + + # Set the server id and device ids for prefill and decode nodes + prompt_server_id = "server-0" + prompt_deivce_ids = "0,1,2,3" + decode_server_id = "server-1" + decode_device_ids = "4,5,6,7" + + prompts = [ + "Hello, how are you today?", + "Hi, what is your name?", + "Tell me a very long story.", + "what is your favourite book?", + ] + num_prompts = len(prompts) + + prompt_q: Queue = Queue(num_prompts) + prefill_done = Event() + process_close = Event() + + prefill_process = Process( + target=run_prefill, + args=(prefill_done, process_close, prompt_q, prompts, model, + prompt_server_id, prompt_deivce_ids), + ) + decode_process = Process( + target=run_decode, + args=(prefill_done, prompt_q, num_prompts, model, decode_server_id, + decode_device_ids), + ) + + # Start prefill node + prefill_process.start() + # Start decode node + decode_process.start() + + # Wait for decode process to finish + decode_process.join() + print("[Main] Decode process done.") + + # Terminate the prefill node, and wait for it to finish + process_close.set() + prefill_process.join() + print("[Main] Prefill process done.") From 2fbd02913e5ee43293c9b9a361f61f6102bc9be5 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Mon, 28 Apr 2025 14:31:14 +0800 Subject: [PATCH 05/21] feat: simplify 1p1d startup Eliminates the need to launch the meta server in the 1p1d environment. 
Signed-off-by: Jade Zheng --- .../offling_inference.py | 5 +- .../distributed/llmdatadist_connector_v1.py | 77 +++++++++++++++---- 2 files changed, 62 insertions(+), 20 deletions(-) diff --git a/examples/disaggregated-prefill-v1/offling_inference.py b/examples/disaggregated-prefill-v1/offling_inference.py index 909d81030b4..a908fef2862 100644 --- a/examples/disaggregated-prefill-v1/offling_inference.py +++ b/examples/disaggregated-prefill-v1/offling_inference.py @@ -18,10 +18,7 @@ def get_kv_transfer_config(role: Literal["kv_producer", "kv_consumer"], "kv_buffer_device": "npu", "kv_role": "{role}", "kv_rank": {kv_rank}, - "kv_parallel_size": 2, - "kv_connector_extra_config": {{ - "local_server_id": "{local_server_id}" - }} + "kv_parallel_size": 2 }}""" diff --git a/vllm_ascend/distributed/llmdatadist_connector_v1.py b/vllm_ascend/distributed/llmdatadist_connector_v1.py index f843cd2c998..dbd257471bc 100644 --- a/vllm_ascend/distributed/llmdatadist_connector_v1.py +++ b/vllm_ascend/distributed/llmdatadist_connector_v1.py @@ -160,7 +160,12 @@ def __init__(self, vllm_config: "VllmConfig") -> None: str, Any] = kv_transfer_config.get_from_extra_config("decode", {}) self._servers: List[ServerInfo] = get_servers_from_ranktable( - GLOBAL_RANKTABLE, self.prefill_tp, self.decode_tp) + GLOBAL_RANKTABLE, self.prefill_tp_size, self.decode_tp_size) + + self._num_prefill_instances = len( + self.get_servers_by_role(ServerRole.Prefill)) + self._num_decode_instances = len( + self.get_servers_by_role(ServerRole.Decode)) def get_device(self, server_id: str, dp_rank: int, tp_rank: int) -> Union[DeviceInfo, None]: @@ -182,6 +187,10 @@ def get_cluster_id(self, server_id: str, dp_rank: int, def get_servers_by_role(self, role: ServerRole) -> List[ServerInfo]: return [server for server in self._servers if server.role == role] + def is_1p1d(self) -> bool: + return (self._num_prefill_instances == 1 + and self._num_decode_instances == 1) + @property def router_endpoint(self): for server in self._servers: @@ -190,28 +199,28 @@ def router_endpoint(self): raise ValueError("Router endpoint not found") @property - def prefill_dp(self): + def prefill_dp_size(self): candidate_keys = ["data_parallel_size", "dp_size", "dp"] return int( self._get_first_matching_value(self._prefill_parallel_config, candidate_keys, self._dp_size)) @property - def prefill_tp(self): + def prefill_tp_size(self): candidate_keys = ["tensor_parallel_size", "tp_size", "tp"] return int( self._get_first_matching_value(self._prefill_parallel_config, candidate_keys, self._tp_size)) @property - def decode_dp(self): + def decode_dp_size(self): candidate_keys = ["data_parallel_size", "dp_size", "dp"] return int( self._get_first_matching_value(self._decode_parallel_config, candidate_keys, self._dp_size)) @property - def decode_tp(self): + def decode_tp_size(self): candidate_keys = ["tensor_parallel_size", "tp_size", "tp"] return int( self._get_first_matching_value(self._decode_parallel_config, @@ -311,7 +320,8 @@ def make_clusters(self): ServerRole.Prefill): for device in server.devices: target_tp_rank = self.tp_rank % min( - self.cluster_info.prefill_tp, self.cluster_info.decode_tp) + self.cluster_info.prefill_tp_size, + self.cluster_info.decode_tp_size) if target_tp_rank == device.tp_rank: cluster = self.make_cluster(device.device_ip, device.cluster_id) @@ -399,9 +409,21 @@ def __init__(self, vllm_config: "VllmConfig", self.local_server_id = kv_transfer_config.get_from_extra_config( "local_server_id", None) - assert ( - self.local_server_id is not None - ), 
"Cannot find `local_server_id` from `kv_transfer_config.kv_connector_extra_config`." + if self.local_server_id is None: + if not self.cluster_info.is_1p1d( + ) or self.cluster_info.prefill_dp_size != 1: + raise ValueError( + "Cannot find `local_server_id` from" + " `kv_transfer_config.kv_connector_extra_config`.") + # In a 1p1d configuration (1 prefill node and 1 decode node), the + # server ID can be directly determined from the rank table based on + # the KV role. + servers = self.cluster_info.get_servers_by_role( + ServerRole.Prefill if self.kv_role == + llm_datadist.LLMRole.PROMPT else ServerRole.Decode) + assert len(servers) == 1, \ + f"Expected only one server for {self.kv_role}, but got {len(servers)}" + self.local_server_id = servers[0].server_id self.dp_rank = self._vllm_config.parallel_config.data_parallel_rank self.tp_size = self._vllm_config.parallel_config.tensor_parallel_size @@ -467,8 +489,25 @@ def start_load_kv(self, forward_context: "ForwardContext", self._get_unique_req_id(req.request_id) for req in metadata.requests if not req.is_store ] - prefill_infos = fetch_prefill_info(self.cluster_info.router_endpoint, - request_ids) + if self.cluster_info.is_1p1d( + ) and self.cluster_info.prefill_dp_size == 1: + # In a 1p1d configuration (1 prefill node and 1 decode node), the + # server ID can be directly determined from the rank table based on + # the KV role. + servers = self.cluster_info.get_servers_by_role(ServerRole.Prefill) + assert len(servers) == 1, \ + f"Expected only one server for {self.kv_role}, but got {len(servers)}" + prefill_infos = { + request_id: { + "dp_rank": 0, + "server_id": servers[0].server_id, + } + for request_id in request_ids + } + else: + prefill_infos = fetch_prefill_info( + self.cluster_info.router_endpoint, request_ids) + # If prefill_infos is None, it indicates that get_prefill_info failed. # Therefore, we need to recalculate the kv cache during the decoding # phase. If there is a performance issue, we should consider whether @@ -518,8 +557,8 @@ def start_load_kv(self, forward_context: "ForwardContext", self.num_layers, kv_cache_shape, kv_hidden_dtype) target_tp_rank = self.tp_rank % min( - self.cluster_info.prefill_tp, - self.cluster_info.decode_tp, + self.cluster_info.prefill_tp_size, + self.cluster_info.decode_tp_size, ) remote_cluster_id = self.cluster_info.get_cluster_id( server_id, dp_rank, target_tp_rank) @@ -662,9 +701,15 @@ def wait_for_save(self): # Release reference count self.llm_datadist_engine.kv_transfer.deallocate_cache(kv_buffer) - # Report prefill info to meta server - report_prefill_info(self.cluster_info.router_endpoint, - prefill_info_input) + # If the cluster is configured as 1p1d (1 prefill node and 1 decode + # node), and the data parallel size on the prefill node is 1, we don't + # need to report the prefill information to the router. This is because + # there is only one candidate server for the decode node to request the + # KV cache from. 
+ if not self.cluster_info.is_1p1d( + ) or self.cluster_info.prefill_dp_size != 1: + report_prefill_info(self.cluster_info.router_endpoint, + prefill_info_input) logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank()) def _inject_kv_into_layer( From 6311b1b232c59f5d0b4922a5e82b77866e536026 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Mon, 28 Apr 2025 14:37:35 +0800 Subject: [PATCH 06/21] chore: rename file Signed-off-by: Jade Zheng --- .../{offling_inference.py => offline_inference.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/disaggregated-prefill-v1/{offling_inference.py => offline_inference.py} (100%) diff --git a/examples/disaggregated-prefill-v1/offling_inference.py b/examples/disaggregated-prefill-v1/offline_inference.py similarity index 100% rename from examples/disaggregated-prefill-v1/offling_inference.py rename to examples/disaggregated-prefill-v1/offline_inference.py From d976098adb8b34d34625ef685764600164339849 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Mon, 28 Apr 2025 15:20:31 +0800 Subject: [PATCH 07/21] chore: lint code Signed-off-by: Jade Zheng --- vllm_ascend/distributed/llmdatadist_connector_v1.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/distributed/llmdatadist_connector_v1.py b/vllm_ascend/distributed/llmdatadist_connector_v1.py index dbd257471bc..6fe55d845b3 100644 --- a/vllm_ascend/distributed/llmdatadist_connector_v1.py +++ b/vllm_ascend/distributed/llmdatadist_connector_v1.py @@ -4,7 +4,7 @@ import struct import time from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import requests import torch @@ -497,7 +497,7 @@ def start_load_kv(self, forward_context: "ForwardContext", servers = self.cluster_info.get_servers_by_role(ServerRole.Prefill) assert len(servers) == 1, \ f"Expected only one server for {self.kv_role}, but got {len(servers)}" - prefill_infos = { + prefill_infos: Dict[str, Any] = { request_id: { "dp_rank": 0, "server_id": servers[0].server_id, @@ -548,8 +548,8 @@ def start_load_kv(self, forward_context: "ForwardContext", kv_cache_shape = (1, 2, slen, num_heads, head_dim) uniq_req_id = self._get_unique_req_id(request.request_id) - dp_rank = prefill_infos[uniq_req_id]["dp_rank"] - server_id = prefill_infos[uniq_req_id]["server_id"] + dp_rank: int = prefill_infos[uniq_req_id]["dp_rank"] + server_id: str = prefill_infos[uniq_req_id]["server_id"] # pull kv cache from prefill node by request kv_hidden_dtype = kv_cache_layers[0].dtype From 2d698e3be7eb317ffd3190467fd72565a8985d66 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Mon, 28 Apr 2025 15:45:24 +0800 Subject: [PATCH 08/21] fix: resolve import issue when running with vllm 0.8.4 Signed-off-by: Jade Zheng --- vllm_ascend/worker/model_runner_v1.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index e1bf0c618dc..20ad123cd0a 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -79,6 +79,17 @@ from vllm_ascend.utils import ProfileExecuteDuration from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer +if vllm_version_is("0.8.4"): + from vllm.distributed import get_kv_transfer_group + + def has_kv_transfer_group() -> bool: + # vLLM 0.8.4 does not support disaggregated prefill. This function is + # added to ensure compatibility with vLLM 0.8.4. 
+ return False +else: + from vllm.distributed.kv_transfer import ( # type: ignore + get_kv_transfer_group, has_kv_transfer_group) + if TYPE_CHECKING: import xgrammar as xgr # type: ignore[import-untyped] from vllm.v1.core.sched.output import SchedulerOutput From 3c829a3dcbc4ce16f4c99be4dffbaf46e1088104 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Tue, 6 May 2025 09:03:42 +0800 Subject: [PATCH 09/21] chore: remove v0.8.4 patch Signed-off-by: Jade Zheng --- vllm_ascend/worker/model_runner_v1.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 20ad123cd0a..e1bf0c618dc 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -79,17 +79,6 @@ from vllm_ascend.utils import ProfileExecuteDuration from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer -if vllm_version_is("0.8.4"): - from vllm.distributed import get_kv_transfer_group - - def has_kv_transfer_group() -> bool: - # vLLM 0.8.4 does not support disaggregated prefill. This function is - # added to ensure compatibility with vLLM 0.8.4. - return False -else: - from vllm.distributed.kv_transfer import ( # type: ignore - get_kv_transfer_group, has_kv_transfer_group) - if TYPE_CHECKING: import xgrammar as xgr # type: ignore[import-untyped] from vllm.v1.core.sched.output import SchedulerOutput From 385ed31a327d6b15bb62480558856c6a41518576 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Wed, 7 May 2025 22:10:40 +0800 Subject: [PATCH 10/21] chore: refine the init optins of datadist Signed-off-by: Jade Zheng --- vllm_ascend/distributed/llmdatadist_connector_v1.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm_ascend/distributed/llmdatadist_connector_v1.py b/vllm_ascend/distributed/llmdatadist_connector_v1.py index 6fe55d845b3..9456133b4ec 100644 --- a/vllm_ascend/distributed/llmdatadist_connector_v1.py +++ b/vllm_ascend/distributed/llmdatadist_connector_v1.py @@ -290,18 +290,15 @@ def __init__(self, role: llm_datadist.LLMRole, local_rank: int, self.role, self.cluster_id) def prepare_data_dist(self): + # TODO: The maximum size of the mbuf for the llm datadist. We need to + # find an appropriate value to minimize memory waste. options = { "llm.SyncKvCacheWaitTime": envs.LLMDATADIST_SYNC_CACHE_WAIT_TIME, + "ge.flowGraphMemMaxSize": f"{(9*1024*1024*1024):d}", + "ge.exec.deviceId": str(self.local_rank), } if self.role == llm_datadist.LLMRole.PROMPT: - # TODO: This represents the maximum size of the mbuf for the llm - # datadist. We need to find an appropriate value to minimize memory - # waste. 
- # options["ge.flowGraphMemMaxSize"] = "1024" # MB - options["ge.exec.deviceId"] = str(self.local_rank) options["llm.listenIpInfo"] = f"{self.local_device_ip}:26000" - else: - options["ge.exec.deviceId"] = str(self.local_rank) self.datadist_engine.init(options) self.kv_transfer = self.datadist_engine.kv_cache_manager From 9e77d4c06cabc09e6a4ed49aebdf7d0d29cb20e0 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Wed, 7 May 2025 22:12:30 +0800 Subject: [PATCH 11/21] feat: refine linking logic Signed-off-by: Jade Zheng --- .../distributed/llmdatadist_connector_v1.py | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/vllm_ascend/distributed/llmdatadist_connector_v1.py b/vllm_ascend/distributed/llmdatadist_connector_v1.py index 9456133b4ec..6ff282c070d 100644 --- a/vllm_ascend/distributed/llmdatadist_connector_v1.py +++ b/vllm_ascend/distributed/llmdatadist_connector_v1.py @@ -434,19 +434,25 @@ def __init__(self, vllm_config: "VllmConfig", self.local_server_id) self.llm_datadist_engine.prepare_data_dist() if self.kv_role == llm_datadist.LLMRole.DECODER: - while True: - try: - # Each decoding rank should correspond to each prefilling rank. - clusters = self.llm_datadist_engine.make_clusters() - _, ret = self.llm_datadist_engine.datadist_engine.link_clusters( - clusters, 20000) - logger.info(f"{local_rank} link, ret={ret}") - break - except LLMException as e: - logger.error( - f"Failed to link clusters, local_rank {local_rank}, error: {e}" - ) - time.sleep(1) + # Each decoding rank should correspond to each prefilling rank. + clusters = self.llm_datadist_engine.make_clusters() + while len(clusters) > 0: + overall_ret, link_rets = \ + self.llm_datadist_engine.datadist_engine.link_clusters( + clusters, timeout=3000) + + if overall_ret != LLMStatusCode.LLM_SUCCESS: + logger.warning(f"Failed to link clusters, {overall_ret=}") + continue + + for idx, link_ret in enumerate(link_rets): + if link_ret == LLMStatusCode.LLM_SUCCESS: + clusters.pop(idx) + + if len(clusters) == 0: + logger.info("Successfully linked clusters") + else: + logger.warning(f"Still {len(clusters)} clusters to link") def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: From 8bf8459643ff0d1ff77df3e9065a38784df72346 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Thu, 8 May 2025 17:35:28 +0800 Subject: [PATCH 12/21] fix: manage KV cache buffer lifecycle to prevent premature deallocation Signed-off-by: Jade Zheng --- .../distributed/llmdatadist_connector_v1.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/distributed/llmdatadist_connector_v1.py b/vllm_ascend/distributed/llmdatadist_connector_v1.py index 6ff282c070d..ea080740958 100644 --- a/vllm_ascend/distributed/llmdatadist_connector_v1.py +++ b/vllm_ascend/distributed/llmdatadist_connector_v1.py @@ -454,6 +454,22 @@ def __init__(self, vllm_config: "VllmConfig", else: logger.warning(f"Still {len(clusters)} clusters to link") + # LLMDataDist will deallocate the cache buffer either when the cache + # buffer's Python object goes out of scope or when deallocate_cache() is + # explicitly called. This can lead to accuracy issues if the cache + # buffer is deallocated while still being used in the NPU stream. To + # prevent this, we maintain a reference to the cache buffer until the + # next round, ensuring it is not prematurely deallocated. 
+ self.kv_buffers: List = [] + + def _detach_kv_buffers(self): + for kv_buffer in self.kv_buffers: + self.llm_datadist_engine.kv_transfer.deallocate_cache(kv_buffer) + self.kv_buffers.clear() + + def _attach_kv_buffer(self, kv_buffer: torch.Tensor): + self.kv_buffers.append(kv_buffer) + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: """ @@ -477,6 +493,9 @@ def start_load_kv(self, forward_context: "ForwardContext", # In the prefilling node, do not need to load KV cache. return + # Release the KV cache buffer from the previous round + self._detach_kv_buffers() + # Get the metadata metadata = self._get_connector_metadata() assert isinstance(metadata, LLMDataDistConnectorV1Metadata) @@ -558,6 +577,7 @@ def start_load_kv(self, forward_context: "ForwardContext", kv_hidden_dtype = kv_cache_layers[0].dtype kv_buffer, pulled_kv_caches = self._create_cache_tensors( self.num_layers, kv_cache_shape, kv_hidden_dtype) + self._attach_kv_buffer(kv_buffer) target_tp_rank = self.tp_rank % min( self.cluster_info.prefill_tp_size, @@ -590,9 +610,6 @@ def start_load_kv(self, forward_context: "ForwardContext", self._inject_kv_into_layer(kv_cache_layer, pulled_kv_cache, req_slot_mapping, is_mla) - # Release the reference count - self.llm_datadist_engine.kv_transfer.deallocate_cache(kv_buffer) - def wait_for_layer_load(self, layer_name: str) -> None: """ Block until the KV for a specific layer is loaded into vLLM's From 945aa93f6cb71c5b8c77da5761f60514f4d9c034 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Fri, 9 May 2025 10:53:56 +0800 Subject: [PATCH 13/21] fix: correct finding the kv cache shape for mha Signed-off-by: Jade Zheng --- vllm_ascend/distributed/llmdatadist_connector_v1.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/distributed/llmdatadist_connector_v1.py b/vllm_ascend/distributed/llmdatadist_connector_v1.py index ea080740958..60ac390d898 100644 --- a/vllm_ascend/distributed/llmdatadist_connector_v1.py +++ b/vllm_ascend/distributed/llmdatadist_connector_v1.py @@ -785,11 +785,12 @@ def _extract_kv_from_layer( """ if is_mla: num_heads, head_dim = kv_cache_layer.shape[ - 2], kv_cache_layer.shape[3] + -2], kv_cache_layer.shape[-1] return kv_cache_layer.view(-1, num_heads, head_dim)[slot_mapping, ...] - num_heads, head_dim = kv_cache_layer.shape[2], kv_cache_layer.shape[3] + num_heads, head_dim = kv_cache_layer.shape[-2], kv_cache_layer.shape[ + -1] return kv_cache_layer.view(2, -1, num_heads, head_dim)[:, slot_mapping, ...] 
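A note on the shape fix in the hunk above: reading `shape[2]`/`shape[3]` only lines up with the head count and head size when the cache has no leading key/value axis, so the MHA layout was previously extracted with the wrong dimensions; the trailing two dimensions give the head count and head size in either case. Below is a minimal, hedged sketch of the layout-agnostic extraction, assuming illustrative paged layouts of `(2, num_blocks, block_size, num_heads, head_dim)` for MHA and `(num_blocks, block_size, 1, head_dim)` for MLA (the exact Ascend cache layouts may differ).

```python
# Hedged sketch of the layout-agnostic extraction. The paged layouts below
# are assumptions for illustration, not the exact Ascend cache layouts.
import torch

# Assumed MHA paged layout: (2, num_blocks, block_size, num_heads, head_dim)
mha_layer = torch.zeros(2, 4, 16, 8, 64)
# Assumed MLA paged layout: (num_blocks, block_size, 1, head_dim)
mla_layer = torch.zeros(4, 16, 1, 64)

slot_mapping = torch.tensor([0, 1, 2, 17])  # flattened slots of one request


def extract_kv(layer: torch.Tensor, slots: torch.Tensor,
               is_mla: bool) -> torch.Tensor:
    # The trailing two dimensions are the head count and head size in both
    # layouts; positional indices 2/3 only match when there is no leading
    # key/value axis.
    num_heads, head_dim = layer.shape[-2], layer.shape[-1]
    if is_mla:
        return layer.view(-1, num_heads, head_dim)[slots, ...]
    return layer.view(2, -1, num_heads, head_dim)[:, slots, ...]


print(extract_kv(mha_layer, slot_mapping, is_mla=False).shape)  # (2, 4, 8, 64)
print(extract_kv(mla_layer, slot_mapping, is_mla=True).shape)   # (4, 1, 64)
```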
From 449b14e4ee709d286122712bafe7ddb2816b5c2d Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Sun, 11 May 2025 10:53:33 +0800 Subject: [PATCH 14/21] fix: reverse iteration over link_rets to safely remove clusters Signed-off-by: Jade Zheng --- vllm_ascend/distributed/llmdatadist_connector_v1.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/distributed/llmdatadist_connector_v1.py b/vllm_ascend/distributed/llmdatadist_connector_v1.py index 60ac390d898..94e90368e53 100644 --- a/vllm_ascend/distributed/llmdatadist_connector_v1.py +++ b/vllm_ascend/distributed/llmdatadist_connector_v1.py @@ -445,7 +445,8 @@ def __init__(self, vllm_config: "VllmConfig", logger.warning(f"Failed to link clusters, {overall_ret=}") continue - for idx, link_ret in enumerate(link_rets): + for idx in range(len(link_rets) - 1, -1, -1): + link_ret = link_rets[idx] if link_ret == LLMStatusCode.LLM_SUCCESS: clusters.pop(idx) From 36a198398a574f3434f4bb5d3e117d3e9a72cd02 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Wed, 14 May 2025 12:06:39 +0800 Subject: [PATCH 15/21] fix: correct block_ids assignment Signed-off-by: Jade Zheng --- vllm_ascend/distributed/llmdatadist_connector_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/distributed/llmdatadist_connector_v1.py b/vllm_ascend/distributed/llmdatadist_connector_v1.py index 94e90368e53..d8fe6fb4938 100644 --- a/vllm_ascend/distributed/llmdatadist_connector_v1.py +++ b/vllm_ascend/distributed/llmdatadist_connector_v1.py @@ -888,7 +888,7 @@ def build_connector_meta( meta.add_request( request_id=cached_req.req_id, token_ids=token_ids, - block_ids=new_req.block_ids, + block_ids=cached_req.block_ids, block_size=self._block_size, is_store=False, ) From 715a569f6d649913d2c0d7cdadf50f021e52202a Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Wed, 14 May 2025 12:09:53 +0800 Subject: [PATCH 16/21] fix: typo Signed-off-by: Jade Zheng --- examples/disaggregated-prefill-v1/offline_inference.py | 2 +- vllm_ascend/distributed/llmdatadist_connector_v1.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/disaggregated-prefill-v1/offline_inference.py b/examples/disaggregated-prefill-v1/offline_inference.py index a908fef2862..53915056512 100644 --- a/examples/disaggregated-prefill-v1/offline_inference.py +++ b/examples/disaggregated-prefill-v1/offline_inference.py @@ -125,7 +125,7 @@ def run_decode( max_model_len=40, ) - # Wait for the producer to start the comsumer + # Wait for the producer to start the consumer print("[Decode] Waiting for prefill node to finish...") prefill_done.wait() diff --git a/vllm_ascend/distributed/llmdatadist_connector_v1.py b/vllm_ascend/distributed/llmdatadist_connector_v1.py index d8fe6fb4938..7763d85794b 100644 --- a/vllm_ascend/distributed/llmdatadist_connector_v1.py +++ b/vllm_ascend/distributed/llmdatadist_connector_v1.py @@ -669,7 +669,7 @@ def wait_for_save(self): indices = torch.tensor([0], dtype=torch.int64, device="npu") prefill_info_input = {} - # kv cache should be transfered by request + # kv cache should be transferred by request for _, request in enumerate(metadata.requests): if not request.is_store: continue @@ -820,7 +820,7 @@ def get_num_new_matched_tokens( # NOTE: only request in waiting queue will come here. 
we use datadist # pull cache to do transfer, so we don't align to block_size in prefill, # we won't have extra new matched tokens; in decode, new request kv - # cache will be transfered from prefill, so num_computed_tokens = 0, and + # cache will be transferred from prefill, so num_computed_tokens = 0, and # extra new matched tokens should be len(request.prompt_token_ids) - 1 if self.kv_role == llm_datadist.LLMRole.PROMPT: return 0 From 969159eda9f5b11447f09c2aa57364b10ccf4a6c Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Thu, 22 May 2025 14:16:41 +0800 Subject: [PATCH 17/21] avoid memory lack Signed-off-by: Jade Zheng --- .../distributed/llmdatadist_connector_v1.py | 296 ++++++++++++------ 1 file changed, 199 insertions(+), 97 deletions(-) diff --git a/vllm_ascend/distributed/llmdatadist_connector_v1.py b/vllm_ascend/distributed/llmdatadist_connector_v1.py index 7763d85794b..9bf1c29d31f 100644 --- a/vllm_ascend/distributed/llmdatadist_connector_v1.py +++ b/vllm_ascend/distributed/llmdatadist_connector_v1.py @@ -1,6 +1,7 @@ import enum import hashlib import json +import random import struct import time from dataclasses import dataclass @@ -14,7 +15,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) from vllm.forward_context import get_forward_context -from vllm.logger import init_logger +from vllm.logger import logger from vllm.v1.core.sched.output import SchedulerOutput if TYPE_CHECKING: @@ -29,8 +30,6 @@ import vllm_ascend.envs as envs from vllm_ascend.attention.mla_v1 import AscendMLAMetadata -logger = init_logger(__name__) - TORCH_DTYPE_TO_NPU_DTYPE = { torch.half: llm_datadist.DataType.DT_FLOAT16, torch.float16: llm_datadist.DataType.DT_FLOAT16, @@ -294,12 +293,13 @@ def prepare_data_dist(self): # find an appropriate value to minimize memory waste. 
options = { "llm.SyncKvCacheWaitTime": envs.LLMDATADIST_SYNC_CACHE_WAIT_TIME, - "ge.flowGraphMemMaxSize": f"{(9*1024*1024*1024):d}", + "ge.flowGraphMemMaxSize": f"{int(2.25*1024*1024*1024):d}", "ge.exec.deviceId": str(self.local_rank), } if self.role == llm_datadist.LLMRole.PROMPT: options["llm.listenIpInfo"] = f"{self.local_device_ip}:26000" self.datadist_engine.init(options) + logger.info("llm_datadist init done") self.kv_transfer = self.datadist_engine.kv_cache_manager def make_cluster(self, prefill_ip, cluster_id=-1): @@ -316,13 +316,9 @@ def make_clusters(self): for server in self.cluster_info.get_servers_by_role( ServerRole.Prefill): for device in server.devices: - target_tp_rank = self.tp_rank % min( - self.cluster_info.prefill_tp_size, - self.cluster_info.decode_tp_size) - if target_tp_rank == device.tp_rank: - cluster = self.make_cluster(device.device_ip, - device.cluster_id) - clusters.append(cluster) + cluster = self.make_cluster(device.device_ip, + device.cluster_id) + clusters.append(cluster) return clusters @@ -404,6 +400,13 @@ def __init__(self, vllm_config: "VllmConfig", init_cluster_info(self._vllm_config) self.cluster_info = get_cluster_info() + if self.cluster_info.prefill_tp_size < self.cluster_info.decode_tp_size: + raise ValueError( + "The prefill tensor parallel size must be greater than or " + f"equal to the decode tensor parallel size, but got " + f"{self.cluster_info.prefill_tp_size} < " + f"{self.cluster_info.decode_tp_size}.") + self.local_server_id = kv_transfer_config.get_from_extra_config( "local_server_id", None) if self.local_server_id is None: @@ -422,12 +425,13 @@ def __init__(self, vllm_config: "VllmConfig", f"Expected only one server for {self.kv_role}, but got {len(servers)}" self.local_server_id = servers[0].server_id - self.dp_rank = self._vllm_config.parallel_config.data_parallel_rank + self.dp_rank = self._vllm_config.parallel_config.data_parallel_rank_local self.tp_size = self._vllm_config.parallel_config.tensor_parallel_size self.tp_rank = get_tensor_model_parallel_rank() self.num_layers = self._vllm_config.model_config.get_num_layers( self._vllm_config.parallel_config) - + if self.tp_size == 1: + local_rank = self.dp_rank local_rank = get_world_group().local_rank self.llm_datadist_engine = KVTransferEngine(self.kv_role, local_rank, self.dp_rank, self.tp_rank, @@ -436,24 +440,23 @@ def __init__(self, vllm_config: "VllmConfig", if self.kv_role == llm_datadist.LLMRole.DECODER: # Each decoding rank should correspond to each prefilling rank. 
clusters = self.llm_datadist_engine.make_clusters() - while len(clusters) > 0: - overall_ret, link_rets = \ - self.llm_datadist_engine.datadist_engine.link_clusters( - clusters, timeout=3000) - - if overall_ret != LLMStatusCode.LLM_SUCCESS: - logger.warning(f"Failed to link clusters, {overall_ret=}") - continue - - for idx in range(len(link_rets) - 1, -1, -1): - link_ret = link_rets[idx] - if link_ret == LLMStatusCode.LLM_SUCCESS: - clusters.pop(idx) - - if len(clusters) == 0: - logger.info("Successfully linked clusters") - else: - logger.warning(f"Still {len(clusters)} clusters to link") + random.shuffle(clusters) + for cluster in clusters: + while True: + link_ret, link_rets = \ + self.llm_datadist_engine.datadist_engine.link_clusters( + [cluster], timeout=30_000) + + if link_ret == LLMStatusCode.LLM_SUCCESS \ + and link_rets[0] == LLMStatusCode.LLM_SUCCESS: + break + + sleep_time = random.uniform(5, 17) + logger.warning( + f"Failed to link cluster({cluster.remote_cluster_id}), " + f"retrying in {sleep_time:.2f} seconds") + time.sleep(sleep_time) + logger.info("Successfully linked clusters") # LLMDataDist will deallocate the cache buffer either when the cache # buffer's Python object goes out of scope or when deallocate_cache() is @@ -463,6 +466,18 @@ def __init__(self, vllm_config: "VllmConfig", # next round, ensuring it is not prematurely deallocated. self.kv_buffers: List = [] + # In graph mode (migrated from v0), the layer KV cache format differs + # from the v1 format. As a result, the KV transfer process requires + # specific handling to accommodate this difference. + additional_config = self._vllm_config.additional_config + self.enable_graph_mode = additional_config and additional_config.get( + "enable_graph_mode", False) + + if self.enable_graph_mode and \ + self.kv_role == llm_datadist.LLMRole.PROMPT: + raise NotImplementedError( + "The graph mode is not supported for prefill node now.") + def _detach_kv_buffers(self): for kv_buffer in self.kv_buffers: self.llm_datadist_engine.kv_transfer.deallocate_cache(kv_buffer) @@ -508,9 +523,22 @@ def start_load_kv(self, forward_context: "ForwardContext", attn_metadata = forward_context.attn_metadata assert attn_metadata is not None, "The attn_metadata should not be None." 
- request_ids = [ - self._get_unique_req_id(req.request_id) - for req in metadata.requests if not req.is_store + request_metadatas = {} + for request in metadata.requests: + if request.is_store: + continue + datadist_request_id = request_id_hex_to_number(request.request_id) + target_tp_rank = self._get_target_tp_rank(datadist_request_id) + unique_req_id = self._get_unique_req_id(request.request_id, + target_tp_rank) + request_metadatas[request.request_id] = { + "datadist_request_id": datadist_request_id, + "target_tp_rank": target_tp_rank, + "unique_req_id": unique_req_id, + } + + unique_req_ids = [ + meta["unique_req_id"] for meta in request_metadatas.values() ] if self.cluster_info.is_1p1d( ) and self.cluster_info.prefill_dp_size == 1: @@ -521,15 +549,15 @@ def start_load_kv(self, forward_context: "ForwardContext", assert len(servers) == 1, \ f"Expected only one server for {self.kv_role}, but got {len(servers)}" prefill_infos: Dict[str, Any] = { - request_id: { + unique_req_id: { "dp_rank": 0, "server_id": servers[0].server_id, } - for request_id in request_ids + for unique_req_id in unique_req_ids } else: prefill_infos = fetch_prefill_info( - self.cluster_info.router_endpoint, request_ids) + self.cluster_info.router_endpoint, unique_req_ids) # If prefill_infos is None, it indicates that get_prefill_info failed. # Therefore, we need to recalculate the kv cache during the decoding @@ -547,8 +575,27 @@ def start_load_kv(self, forward_context: "ForwardContext", forward_context.virtual_engine] kv_cache_layers.append(kv_cache_layer) + if self.enable_graph_mode: + # Currently, the graph mode is migrated from the v0, and the kv + # cache layer is a tuple. The first element is + # 'layer_kv_cache_nope', and the second element is + # 'layer_kv_cache_pe'. + assert isinstance(kv_cache_layers[0], tuple) and \ + len(kv_cache_layers[0]) == 2, ( + "The kv_cache_layer should be a tuple of two tensors for " + "current graph mode.") + layer_kv_cache_nope_shape = kv_cache_layers[0][0].shape + layer_kv_cache_pe_shape = kv_cache_layers[0][1].shape + kv_lora_rank = layer_kv_cache_nope_shape[-1] + qk_rope_head_dim = layer_kv_cache_pe_shape[-1] + kv_cache_layer_shape = list(layer_kv_cache_nope_shape[:-1]) + \ + [kv_lora_rank + qk_rope_head_dim] + kv_hidden_dtype = kv_cache_layers[0][0].dtype + else: + kv_cache_layer_shape = list(kv_cache_layers[0].shape) + kv_hidden_dtype = kv_cache_layers[0].dtype + is_mla = isinstance(attn_metadata, AscendMLAMetadata) - kv_cache_layer_shape = list(kv_cache_layers[0].shape) num_heads = int(kv_cache_layer_shape[-2]) head_dim = int(kv_cache_layer_shape[-1]) # Load the KV for each request each layer @@ -570,28 +617,24 @@ def start_load_kv(self, forward_context: "ForwardContext", # [1, 2, slen, num_heads, head_dim] kv_cache_shape = (1, 2, slen, num_heads, head_dim) - uniq_req_id = self._get_unique_req_id(request.request_id) - dp_rank: int = prefill_infos[uniq_req_id]["dp_rank"] - server_id: str = prefill_infos[uniq_req_id]["server_id"] - - # pull kv cache from prefill node by request - kv_hidden_dtype = kv_cache_layers[0].dtype - kv_buffer, pulled_kv_caches = self._create_cache_tensors( - self.num_layers, kv_cache_shape, kv_hidden_dtype) - self._attach_kv_buffer(kv_buffer) + # Each request uses the same llm_datadist request_id, which needs to + # be converted into an integer value. 
+ request_metadata = request_metadatas[request.request_id] + datadist_request_id = request_metadata["datadist_request_id"] + target_tp_rank = request_metadata["target_tp_rank"] + unique_req_id = request_metadata["unique_req_id"] - target_tp_rank = self.tp_rank % min( - self.cluster_info.prefill_tp_size, - self.cluster_info.decode_tp_size, - ) + dp_rank: int = prefill_infos[unique_req_id]["dp_rank"] + server_id: str = prefill_infos[unique_req_id]["server_id"] remote_cluster_id = self.cluster_info.get_cluster_id( server_id, dp_rank, target_tp_rank) - - # Each request uses the same llm_datadist request_id, which needs to - # be converted into an integer value. - datadist_request_id = string_to_int64_hash(request.request_id) kv_cache_key = llm_datadist.CacheKey(remote_cluster_id, datadist_request_id, 1) + + # pull kv cache from prefill node by request + kv_buffer, pulled_kv_caches = self._create_cache_tensors( + self.num_layers, kv_cache_shape, kv_hidden_dtype) + self._attach_kv_buffer(kv_buffer) self.llm_datadist_engine.kv_transfer.pull_cache( kv_cache_key, kv_buffer, 0) @@ -608,8 +651,18 @@ def start_load_kv(self, forward_context: "ForwardContext", for layer_id, kv_cache_layer in enumerate(kv_cache_layers): pulled_kv_cache = pulled_kv_caches[layer_id] - self._inject_kv_into_layer(kv_cache_layer, pulled_kv_cache, - req_slot_mapping, is_mla) + + if self.enable_graph_mode: + pulled_kv_cache = torch.split(pulled_kv_cache, + dim=-1, + split_size_or_sections=[ + kv_lora_rank, + qk_rope_head_dim + ]) + + self._inject_kv_into_layer( + kv_cache_layer, pulled_kv_cache, req_slot_mapping, is_mla + and not self.enable_graph_mode) def wait_for_layer_load(self, layer_name: str) -> None: """ @@ -652,12 +705,14 @@ def wait_for_save(self): # layer operations, all send is done in wait_for_save if self.kv_role == llm_datadist.LLMRole.DECODER: - # In the prompt role, we do not need to load KV cache. + # In the decoder role, we do not need to save KV cache. return forward_context = get_forward_context() metadata: KVConnectorMetadata = self._get_connector_metadata() - assert isinstance(metadata, LLMDataDistConnectorV1Metadata) + assert isinstance(metadata, LLMDataDistConnectorV1Metadata), \ + ("metadata should be LLMDataDistConnectorV1Metadata, but got " + f"{type(metadata)}.") attn_metadata = forward_context.attn_metadata if attn_metadata is None: @@ -677,12 +732,27 @@ def wait_for_save(self): slen = request.token_ids.shape[0] req_slot_mapping = request.slot_mapping[:slen] - uniq_req_id = self._get_unique_req_id(request.request_id) + uniq_req_id = self._get_unique_req_id(request.request_id, + self.tp_rank) prefill_info_input[uniq_req_id] = { "dp_rank": self.dp_rank, "server_id": self.local_server_id, } + # Initialize LLMDatadist data structure. Each request uses the same + # llm_datadist request_id, which needs to be converted to an integer + # value. + datadist_request_id = request_id_hex_to_number(request.request_id) + kv_cache_keys = [ + llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, + datadist_request_id, 1) + ] + + if not self._need_save_kv(datadist_request_id): + # We choose some ranks to save kv cache randomly, if the rank is + # not selected, we do not need to save kv cache. + continue + kv_caches: List[torch.Tensor] = [] for _, attn_layer in forward_context.no_compile_layers.items(): kv_cache_layer = attn_layer.kv_cache[ @@ -692,15 +762,6 @@ def wait_for_save(self): is_mla) kv_caches.append(kv_cache.detach()) - # Initialize LLMDatadist data structure. 
Each request uses the same - # llm_datadist request_id, which needs to be converted to an integer - # value. - datadist_request_id = string_to_int64_hash(request.request_id) - kv_cache_keys = [ - llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, - datadist_request_id, 1) - ] - # If MLA is used, the kv_cache_shape should be (1, slen, num_heads, # head_dim). Otherwise, it should be (1, 2, slen, num_heads, # head_dim). The first dimension must be 1, because the following @@ -727,18 +788,21 @@ def wait_for_save(self): # need to report the prefill information to the router. This is because # there is only one candidate server for the decode node to request the # KV cache from. - if not self.cluster_info.is_1p1d( - ) or self.cluster_info.prefill_dp_size != 1: + if len(prefill_info_input) > 0 and ( + not self.cluster_info.is_1p1d() + or self.cluster_info.prefill_dp_size != 1): report_prefill_info(self.cluster_info.router_endpoint, prefill_info_input) logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank()) def _inject_kv_into_layer( self, - dst_kv_cache_layer: torch.Tensor, - pulled_kv_cache: torch.Tensor, + dst_kv_cache_layer: Union[torch.Tensor, tuple[torch.Tensor, + torch.Tensor]], + pulled_kv_cache: Union[torch.Tensor, tuple[torch.Tensor, + torch.Tensor]], slot_mapping: torch.Tensor, - is_mla: bool, + use_siso: bool, ) -> None: """Inject the KV cache into the layer. @@ -755,8 +819,12 @@ def _inject_kv_into_layer( """ # The `wait_for_save` function explains why the first dimension is # necessary. - kv_cache = pulled_kv_cache.squeeze(0) - if is_mla: + if isinstance(pulled_kv_cache, tuple): + kv_cache = [cache.squeeze(0) for cache in pulled_kv_cache] + else: + kv_cache = pulled_kv_cache.squeeze(0) + + if use_siso: torch_npu._npu_reshape_and_cache_siso( key=kv_cache, key_cache=dst_kv_cache_layer, @@ -764,6 +832,8 @@ def _inject_kv_into_layer( ) else: + kv_cache[0] = kv_cache[0].contiguous() + kv_cache[1] = kv_cache[1].contiguous() torch_npu._npu_reshape_and_cache( key=kv_cache[0], value=kv_cache[1], @@ -820,8 +890,9 @@ def get_num_new_matched_tokens( # NOTE: only request in waiting queue will come here. we use datadist # pull cache to do transfer, so we don't align to block_size in prefill, # we won't have extra new matched tokens; in decode, new request kv - # cache will be transferred from prefill, so num_computed_tokens = 0, and - # extra new matched tokens should be len(request.prompt_token_ids) - 1 + # cache will be transferred from prefill, so num_computed_tokens = 0, + # and extra new matched tokens should be len(request.prompt_token_ids) - + # 1 if self.kv_role == llm_datadist.LLMRole.PROMPT: return 0 return len(request.prompt_token_ids) - 1 @@ -876,23 +947,9 @@ def build_connector_meta( # the first N requests in the list scheduled_cache_reqs. if not cached_req.resumed_from_preemption: break - if cached_req.req_id in self._requests_need_load: - # NOTE(rob): cached_req_data does not have the full - # list of token ids (only new tokens). So we look it - # up in the actual request object. 
- request = self._requests_need_load[cached_req.req_id] - total_tokens = len( - cached_req.new_token_ids) + cached_req.num_computed_tokens - token_ids = request.all_token_ids[:total_tokens] - - meta.add_request( - request_id=cached_req.req_id, - token_ids=token_ids, - block_ids=cached_req.block_ids, - block_size=self._block_size, - is_store=False, - ) - total_need_load += 1 + raise NotImplementedError( + "Resumed requests are not supported in this version of the " + "connector.") assert total_need_load == len(self._requests_need_load) self._requests_need_load.clear() @@ -931,8 +988,33 @@ def _create_cache_tensors(self, cache_desc.shape, dtype, cache_buf_addrs) return cache_buf, cache_tensors - def _get_unique_req_id(self, request_id: str) -> str: - return f"{request_id}-{self.tp_rank}" + def _get_unique_req_id(self, request_id: str, tp_rank: int) -> str: + return f"{request_id}-{tp_rank}" + + def _get_prefill_tp_ranks_for_req(self, datadist_req_id: int) -> list[int]: + """Based on the LLMDataDist request id, select a subset of tensor + parallel ranks. Specifically, choose `decode_tp_size` ranks randomly + from all available prefill TP ranks. These selected ranks are + responsible for saving the KV cache for the current request.""" + if self.cluster_info.prefill_tp_size == self.cluster_info.decode_tp_size: + return list(range(self.cluster_info.prefill_tp_size)) + + rand = random.Random(datadist_req_id) + sampled_nums = rand.sample(range(self.cluster_info.prefill_tp_size), + self.cluster_info.decode_tp_size) + return sampled_nums + + def _need_save_kv(self, datadist_req_id: int) -> bool: + """Determines whether the current rank needs to save the KV cache for a + given LLMDataDist request ID.""" + return self.tp_rank in self._get_prefill_tp_ranks_for_req( + datadist_req_id) + + def _get_target_tp_rank(self, datadist_req_id: int) -> int: + """Determines the target tensor parallel (TP) rank for a given TP rank + and LLMDataDist request ID.""" + return self._get_prefill_tp_ranks_for_req(datadist_req_id)[ + self.tp_rank] # ============================== @@ -967,3 +1049,23 @@ def string_to_int64_hash(input_str): trunked_bytes = hashed_bytes[:8] uint64_value = struct.unpack(" Date: Sun, 15 Jun 2025 17:20:28 +0800 Subject: [PATCH 18/21] chore: lint code Signed-off-by: Jade Zheng --- vllm_ascend/envs.py | 1 - vllm_ascend/worker/model_runner_v1.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 21b95e87d9a..ec97ad7f253 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -133,7 +133,6 @@ # value to False to disable the optimized model. 
"USE_OPTIMIZED_MODEL": lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))), - "GLOBAL_RANKTABLE": lambda: os.getenv("GLOBAL_RANKTABLE", None) } diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index e1bf0c618dc..c4512929298 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -37,9 +37,9 @@ from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.distributed.parallel_state import get_dp_group, get_pp_group from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) +from vllm.distributed.parallel_state import get_dp_group, get_pp_group from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY from vllm.logger import logger From 02e9248d23c54d8b321ba5d02464e1464ce295ee Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Sun, 15 Jun 2025 17:32:26 +0800 Subject: [PATCH 19/21] fix: allocation failure for cache_tensor despite sufficient mbuf Signed-off-by: Jade Zheng --- .../distributed/llmdatadist_connector_v1.py | 31 ++++++++++--------- vllm_ascend/envs.py | 11 +++++-- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/vllm_ascend/distributed/llmdatadist_connector_v1.py b/vllm_ascend/distributed/llmdatadist_connector_v1.py index 9bf1c29d31f..296375e67c6 100644 --- a/vllm_ascend/distributed/llmdatadist_connector_v1.py +++ b/vllm_ascend/distributed/llmdatadist_connector_v1.py @@ -25,9 +25,9 @@ from vllm.v1.request import Request import llm_datadist # type: ignore -from llm_datadist import LLMException, LLMStatusCode +from llm_datadist import LLMConfig, LLMException, LLMStatusCode -import vllm_ascend.envs as envs +import vllm_ascend.envs as envs_ascend from vllm_ascend.attention.mla_v1 import AscendMLAMetadata TORCH_DTYPE_TO_NPU_DTYPE = { @@ -41,7 +41,7 @@ torch.int32: llm_datadist.DataType.DT_INT32, } -GLOBAL_RANKTABLE = envs.GLOBAL_RANKTABLE +GLOBAL_RANKTABLE = envs_ascend.LLMDATADIST_GLOBAL_RANKTABLE class ServerRole(enum.Enum): @@ -289,16 +289,19 @@ def __init__(self, role: llm_datadist.LLMRole, local_rank: int, self.role, self.cluster_id) def prepare_data_dist(self): - # TODO: The maximum size of the mbuf for the llm datadist. We need to - # find an appropriate value to minimize memory waste. 
- options = { - "llm.SyncKvCacheWaitTime": envs.LLMDATADIST_SYNC_CACHE_WAIT_TIME, - "ge.flowGraphMemMaxSize": f"{int(2.25*1024*1024*1024):d}", + buff_size = envs_ascend.LLMDATADIST_BUFFSIZE_MB * 1024 * 1024 + llm_config = LLMConfig() + llm_config.ge_options = { + "llm.SyncKvCacheWaitTime": + envs_ascend.LLMDATADIST_SYNC_CACHE_WAIT_TIME, + "ge.flowGraphMemMaxSize": f"{buff_size:d}", "ge.exec.deviceId": str(self.local_rank), } + llm_config.buf_pool_cfg = '{"buf_cfg": [{"total_size":2097152,"blk_size":256,"max_buf_size":256}]}' if self.role == llm_datadist.LLMRole.PROMPT: - options["llm.listenIpInfo"] = f"{self.local_device_ip}:26000" - self.datadist_engine.init(options) + llm_config.listen_ip_info = f"{self.local_device_ip}:26000" + engine_options = llm_config.generate_options() + self.datadist_engine.init(engine_options) logger.info("llm_datadist init done") self.kv_transfer = self.datadist_engine.kv_cache_manager @@ -869,7 +872,7 @@ def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int, - ) -> int: + ) -> tuple[int, bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. @@ -887,15 +890,15 @@ def get_num_new_matched_tokens( # the block granularity. And it expects the returned blocks and # num_computed_tokens to also be aligned with the block granularity. - # NOTE: only request in waiting queue will come here. we use datadist + # NOTE: only requests in waiting queue will come here. we use datadist # pull cache to do transfer, so we don't align to block_size in prefill, # we won't have extra new matched tokens; in decode, new request kv # cache will be transferred from prefill, so num_computed_tokens = 0, # and extra new matched tokens should be len(request.prompt_token_ids) - # 1 if self.kv_role == llm_datadist.LLMRole.PROMPT: - return 0 - return len(request.prompt_token_ids) - 1 + return 0, False + return len(request.prompt_token_ids) - 1, False def update_state_after_alloc(self, request: "Request", num_external_tokens: int): diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index ec97ad7f253..e158a36f05f 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -96,6 +96,15 @@ # 5000ms. "LLMDATADIST_SYNC_CACHE_WAIT_TIME": lambda: os.getenv("LLMDATADIST_SYNC_CACHE_WAIT_TIME", "5000"), + # The path to the llmdatadist global rank table. If not set, the default + # value is None. This results in an error if the global rank table is + # required but not specified. + "LLMDATADIST_GLOBAL_RANKTABLE": + lambda: os.getenv("LLMDATADIST_GLOBAL_RANKTABLE", None), + # The buffer size in MB for llmdatadist communication. If not set, the + # default value is 2560 MB. + "LLMDATADIST_BUFFSIZE_MB": + lambda: int(os.getenv("LLMDATADIST_BUFFSIZE_MB", 2560)), # The version of vllm is installed. This value is used for developers who # installed vllm from source locally. In this case, the version of vllm is # usually changed. For example, if the version of vllm is "0.9.0", but when @@ -133,8 +142,6 @@ # value to False to disable the optimized model. 
"USE_OPTIMIZED_MODEL": lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))), - "GLOBAL_RANKTABLE": - lambda: os.getenv("GLOBAL_RANKTABLE", None) } # end-env-vars-definition From 96bc809a28ca75098209f2be9ee2cb504d1ed685 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Sun, 15 Jun 2025 19:34:27 +0800 Subject: [PATCH 20/21] works Signed-off-by: Jade Zheng --- .../disagg_prefill_proxy_server.py | 13 +++---- .../distributed/llmdatadist_connector_v1.py | 12 ++++-- vllm_ascend/worker/model_runner_v1.py | 37 +++++++++++++++---- 3 files changed, 44 insertions(+), 18 deletions(-) diff --git a/examples/disaggregated-prefill-v1/disagg_prefill_proxy_server.py b/examples/disaggregated-prefill-v1/disagg_prefill_proxy_server.py index fadf1f0387d..e2c7706a6ef 100644 --- a/examples/disaggregated-prefill-v1/disagg_prefill_proxy_server.py +++ b/examples/disaggregated-prefill-v1/disagg_prefill_proxy_server.py @@ -2,6 +2,7 @@ import json import os +import uuid import aiohttp from quart import Quart, make_response, request # type: ignore @@ -32,7 +33,9 @@ async def forward_request(url, data, headers: dict): async def handle_request(): try: original_request_data = await request.get_json() - print(f"{request.headers.get('X-Request-ID')=}") + if isinstance(original_request_data["prompt"], str): + original_request_data["prompt"] = [original_request_data["prompt"]] + request_id = request.headers.get('X-Request-ID', uuid.uuid4().hex) prefill_request = original_request_data.copy() # Change max_tokens = 1 to let it only do prefill @@ -42,9 +45,7 @@ async def handle_request(): async for prefill_result in forward_request( f"http://{PREFILL_ENDPOINT}/v1/completions", prefill_request, - headers={ - "X-Request-ID": request.headers.get("X-Request-ID"), - }, + headers={"X-Request-ID": request_id}, ): # Print the prefill result print("===== Prefill result =====") @@ -62,9 +63,7 @@ async def handle_request(): generator = forward_request( f"http://{DECODE_ENDPOINT}/v1/completions", decode_request, - headers={ - "X-Request-ID": request.headers.get("X-Request-ID"), - }, + headers={"X-Request-ID": request_id}, ) response = await make_response(generator) response.timeout = None diff --git a/vllm_ascend/distributed/llmdatadist_connector_v1.py b/vllm_ascend/distributed/llmdatadist_connector_v1.py index 296375e67c6..b8aa06baddb 100644 --- a/vllm_ascend/distributed/llmdatadist_connector_v1.py +++ b/vllm_ascend/distributed/llmdatadist_connector_v1.py @@ -1,6 +1,7 @@ import enum import hashlib import json +import os import random import struct import time @@ -16,6 +17,7 @@ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) from vllm.forward_context import get_forward_context from vllm.logger import logger +from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput if TYPE_CHECKING: @@ -116,6 +118,9 @@ def parse_server_group(group, role: ServerRole, devices=device_infos)) return server_infos + if ranktable_path is None or not os.path.exists(ranktable_path): + raise FileNotFoundError(f"Rank table file not found: {ranktable_path}") + with open(ranktable_path, "r") as file: rank_table = json.load(file) @@ -433,8 +438,6 @@ def __init__(self, vllm_config: "VllmConfig", self.tp_rank = get_tensor_model_parallel_rank() self.num_layers = self._vllm_config.model_config.get_num_layers( self._vllm_config.parallel_config) - if self.tp_size == 1: - local_rank = self.dp_rank local_rank = get_world_group().local_rank self.llm_datadist_engine = KVTransferEngine(self.kv_role, local_rank, 
self.dp_rank, self.tp_rank, @@ -796,7 +799,9 @@ def wait_for_save(self): or self.cluster_info.prefill_dp_size != 1): report_prefill_info(self.cluster_info.router_endpoint, prefill_info_input) - logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank()) + + if self.tp_rank == 0: + logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank()) def _inject_kv_into_layer( self, @@ -901,6 +906,7 @@ def get_num_new_matched_tokens( return len(request.prompt_token_ids) - 1, False def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", num_external_tokens: int): """ Update KVConnector state after block allocation. diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index c4512929298..3d8c90c87e3 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -39,8 +39,9 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import get_dp_group, get_pp_group -from vllm.forward_context import set_forward_context +from vllm.forward_context import get_forward_context, set_forward_context from vllm.inputs import INPUT_REGISTRY from vllm.logger import logger from vllm.model_executor.layers.fused_moe import FusedMoE @@ -591,7 +592,7 @@ def _make_attention_mask(self, seq_lens, query_lens, position, seq_lens, query_lens, position, self.dtype, self.device) # Prefill without cache situation. elif attn_state == AscendAttentionState.PrefillNoCache: - max_seq_len = max(seq_lens, default=0) + max_seq_len = 128 # max(seq_lens, default=0) return self.attn_mask_builder.get_attn_mask( max_seq_len, self.dtype, self.device) # Prefill with cache hit. @@ -995,6 +996,8 @@ def _process_reqs( with set_forward_context(attn_metadata, self.vllm_config, num_tokens=num_input_tokens): + self.maybe_setup_kv_connector(scheduler_output) + with ProfileExecuteDuration().capture_async("forward"): model_kwargs = {} if self.torchair_graph_enabled: @@ -1020,6 +1023,8 @@ def _process_reqs( **model_kwargs, ) + self.maybe_wait_for_kv_save() + use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: @@ -1211,11 +1216,6 @@ def execute_model( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, torch.Tensor]: - # Update KVConnector with the KVConnector metadata forward(). - if has_kv_transfer_group(): - get_kv_transfer_group().bind_connector_metadata( - scheduler_output.kv_connector_metadata) - with ProfileExecuteDuration().capture_async( "prepare input and forward"): self._update_states(scheduler_output) @@ -1534,7 +1534,7 @@ def profile_run(self) -> None: # TODO: call maybe_profile_with_lora() # Trigger compilation for general shape. - hidden_states = self._dummy_run(self.max_num_tokens) + hidden_states = self._dummy_run(num_tokens) if get_pp_group().is_last_rank: hidden_states = hidden_states[logit_indices] @@ -1913,3 +1913,24 @@ def select_torchair_padded_batch_size(self, batch_size: int): if batch_size <= padded_batch_size < selected_batch_size: selected_batch_size = padded_batch_size return selected_batch_size + + @staticmethod + def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): + # Update KVConnector with the KVConnector metadata forward(). 
+ if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + assert scheduler_output.kv_connector_metadata is not None + kv_connector.bind_connector_metadata( + scheduler_output.kv_connector_metadata) + + # Background KV cache transfers happen here. + # These transfers are designed to be async and the requests + # involved may be disjoint from the running requests. + # Do this here to save a collective_rpc. + kv_connector.start_load_kv(get_forward_context()) + + @staticmethod + def maybe_wait_for_kv_save() -> None: + if has_kv_transfer_group(): + get_kv_transfer_group().wait_for_save() From 3f1b0337f2b72506405a082c6fa16f9910ebcad4 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Sun, 15 Jun 2025 19:39:33 +0800 Subject: [PATCH 21/21] chore: lint Signed-off-by: Jade Zheng --- vllm_ascend/distributed/llmdatadist_connector_v1.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/distributed/llmdatadist_connector_v1.py b/vllm_ascend/distributed/llmdatadist_connector_v1.py index b8aa06baddb..1786ae7a072 100644 --- a/vllm_ascend/distributed/llmdatadist_connector_v1.py +++ b/vllm_ascend/distributed/llmdatadist_connector_v1.py @@ -801,7 +801,8 @@ def wait_for_save(self): prefill_info_input) if self.tp_rank == 0: - logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank()) + logger.info("[rank%d][P]: KV send DONE.", + torch.distributed.get_rank()) def _inject_kv_into_layer( self,
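For a quick end-to-end check once the prefill instance, the decode instance, and the proxy are running, a single completion request can be sent through the proxy. Below is a hedged client sketch, assuming the proxy listens on localhost:8000 and the instances serve the DeepSeek-R1-Distill-Qwen-1.5B model used in the offline example; the address, port, and model name are deployment-specific assumptions. A plain string prompt is accepted, since the proxy wraps it into a list, and the X-Request-ID header is optional because the proxy generates one when it is absent.

```python
# Hedged client sketch: send one completion request through the disaggregated
# prefill proxy. The proxy address and model name are assumptions.
import uuid

import requests

PROXY_URL = "http://localhost:8000/v1/completions"  # assumed proxy address

payload = {
    "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    "prompt": "Hello, how are you today?",
    "max_tokens": 32,
    "temperature": 0,
}

# Optional: the proxy falls back to uuid4().hex when this header is missing.
headers = {"X-Request-ID": uuid.uuid4().hex}

resp = requests.post(PROXY_URL, json=payload, headers=headers, timeout=600)
resp.raise_for_status()
print(resp.text)
```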