1 change: 1 addition & 0 deletions .github/workflows/_e2e_test.yaml
@@ -120,6 +120,7 @@ jobs:
 pytest -sv --durations=0 tests/e2e/singlecard/pooling/
 pytest -sv --durations=0 tests/e2e/singlecard/compile/test_norm_quant_fusion.py
 pytest -sv --durations=0 tests/e2e/singlecard/test_multistream_overlap_shared_expert.py
+pytest -sv --durations=0 tests/e2e/singlecard/test_cpu_offloading.py
 
 # ------------------------------------ v1 spec decode test ------------------------------------ #
 pytest -sv --durations=0 tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py
178 changes: 178 additions & 0 deletions tests/e2e/singlecard/test_cpu_offloading.py
@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import socket
import time

import msgspec
import msgspec.msgpack
import zmq
from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVEventsConfig, KVTransferConfig
from vllm.distributed.kv_events import BlockStored, KVEventBatch


class MockSubscriber:
"""Helper class to receive and verify published events"""

def __init__(
self,
endpoint: str,
topic: str,
):
self.ctx = zmq.Context.instance() # type: ignore
self.topic_bytes = topic.encode("utf-8")

# Set up subscriber socket
self.sub = self.ctx.socket(zmq.SUB) # type: ignore
self.sub.setsockopt(zmq.SUBSCRIBE, self.topic_bytes) # type: ignore
self.sub.connect(endpoint)

self.decoder = msgspec.msgpack.Decoder(type=KVEventBatch)

def get_new_cpu_stored_events(self) -> list[BlockStored]:
cpu_stored_events: list[BlockStored] = []

poller = zmq.Poller() # type: ignore
poller.register(self.sub, zmq.POLLIN) # type: ignore

timeout = 1000 # 1 second
while True:
events = dict(poller.poll(timeout))

if events.get(self.sub) != zmq.POLLIN: # type: ignore
return cpu_stored_events

topic_bytes, _, payload = self.sub.recv_multipart()

assert topic_bytes == self.topic_bytes

event_batch = self.decoder.decode(payload)
assert isinstance(event_batch, KVEventBatch)
for event in event_batch.events:
if isinstance(event, BlockStored) and event.medium == "CPU":
cpu_stored_events.append(event)
timeout = 100

def close(self):
"""Clean up resources"""
self.sub.close()


def _latency_test(llm: LLM, subscriber: MockSubscriber):
    sampling_params = SamplingParams(max_tokens=1)

    num_times_cpu_better_than_cold = 0
    num_tests = 10
    total_cold_time = 0.0
    total_gpu_hit_time = 0.0
    total_cpu_hit_time = 0.0
    prompt_token_ids = [0] * 10001
    for i in range(num_tests):
        prompt_token_ids[0] = i
        prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]

        # run generation - this should trigger saving KV cache
        start_time = time.time()
        llm.generate(prompts, sampling_params, use_tqdm=False)
        cold_time = time.time() - start_time
        total_cold_time += cold_time

        # run generation again - should hit the GPU prefix cache
        start_time = time.time()
        llm.generate(prompts, sampling_params, use_tqdm=False)
        gpu_hit_time = time.time() - start_time
        total_gpu_hit_time += gpu_hit_time

        # reset prefix cache to avoid GPU hit.
        llm.reset_prefix_cache()

        assert subscriber.get_new_cpu_stored_events()

        # run generation again - this should trigger loading from CPU
        start_time = time.time()
        llm.generate(prompts, sampling_params, use_tqdm=False)
        cpu_hit_time = time.time() - start_time
        total_cpu_hit_time += cpu_hit_time

        if cpu_hit_time < cold_time:
            num_times_cpu_better_than_cold += 1

    print("Average times:")
    print(f" Cold: {total_cold_time * 1000 / num_tests:.2f}ms")
    print(f" GPU hit: {total_gpu_hit_time * 1000 / num_tests:.2f}ms")
    print(f" CPU hit: {total_cpu_hit_time * 1000 / num_tests:.2f}ms")

    assert num_times_cpu_better_than_cold >= 0.8 * num_tests


def _accuracy_test(llm: LLM, subscriber: MockSubscriber):
    sampling_params = SamplingParams(max_tokens=1)
    cpu_block_size = (llm.llm_engine.vllm_config.kv_transfer_config.
                      kv_connector_extra_config["block_size"])

    subscriber.get_new_cpu_stored_events()

    # prepend prompt to be cpu block aligned
    prompt = "Let's count to 10. One, two, three, four,"
    while (len(llm.generate(prompt, use_tqdm=False)[0].prompt_token_ids) %
           cpu_block_size != 0):
        prompt = ". " + prompt

    assert subscriber.get_new_cpu_stored_events()

    test_count = 100
    success_count = 0
    for i in range(test_count):
        if (llm.generate(prompt, sampling_params,
                         use_tqdm=False)[0].outputs[0].text == " five"):
            success_count += 1

    assert success_count >= 0.5 * test_count


def test_cpu_offloading() -> None:
"""
Tests OffloadingConnector with CPUOffloadingSpec.
"""

# configure OffloadingConnector (spec_name=CPUOffloadingSpec by default)
    kv_transfer_config = KVTransferConfig(
        kv_connector="OffloadingConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "num_cpu_blocks": 1000,
            "block_size": 128,
            "spec_name": "NPUOffloadingSpec",
            "spec_module_path": "vllm_ascend.kv_offload.npu"
        },
    )

    port: int
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("0.0.0.0", 0))
        port = s.getsockname()[1]

    events_endpoint = f"tcp://*:{port}"
    kv_events_config = KVEventsConfig(
        enable_kv_cache_events=True,
        publisher="zmq",
        endpoint=events_endpoint,
        topic="test",
    )

    llm = LLM(
        model="Qwen/Qwen3-0.6B",
        gpu_memory_utilization=0.5,
        kv_events_config=kv_events_config,
        kv_transfer_config=kv_transfer_config,
    )

    events_endpoint = events_endpoint.replace("*", "127.0.0.1")
    subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic)

    try:
        _latency_test(llm, subscriber)
        _accuracy_test(llm, subscriber)
    finally:
        subscriber.close()
        del llm
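
Note: the KVTransferConfig wiring exercised by this test is also what enables KV offloading to CPU on vllm-ascend outside the test harness; the KV events publisher above exists only so the test can verify that blocks were actually stored. A minimal sketch reusing the values from the test (model name and sizes are illustrative, not requirements):

# Minimal sketch: KV cache offloading to CPU via OffloadingConnector with the
# vllm-ascend NPUOffloadingSpec, using the same values as the test above.
# The test itself runs with: pytest -sv tests/e2e/singlecard/test_cpu_offloading.py
from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="Qwen/Qwen3-0.6B",
    kv_transfer_config=KVTransferConfig(
        kv_connector="OffloadingConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "num_cpu_blocks": 1000,  # size of the CPU-side block pool
            "block_size": 128,  # offloading block size, in tokens
            "spec_name": "NPUOffloadingSpec",
            "spec_module_path": "vllm_ascend.kv_offload.npu",
        },
    ),
)
outputs = llm.generate("Hello, my name is")
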
17 changes: 5 additions & 12 deletions vllm_ascend/kv_offload/npu.py
@@ -2,8 +2,8 @@
 from typing import Optional
 
 import torch
-from vllm.config import VllmConfig, get_layers_from_vllm_config
-from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.config import VllmConfig
 from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
 from vllm.v1.kv_offload.backends.cpu import CPUBackend
 from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
@@ -45,19 +45,12 @@ def get_manager(self) -> OffloadingManager:
         return self._manager
 
     def get_handlers(
-            self, kv_caches: dict[str, torch.Tensor]
+        self,
+        kv_caches: dict[str, torch.Tensor],
+        attn_backends: dict[str, type[AttentionBackend]],
     ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
                         OffloadingHandler]]:
         if not self._handler:
-            layer_names = list(kv_caches.keys())
-            layers = get_layers_from_vllm_config(self.vllm_config,
-                                                 AttentionLayerBase,
-                                                 layer_names)
-            attn_backends = {
-                layer_name: layers[layer_name].get_attn_backend()
-                for layer_name in layer_names
-            }
-
             self._handler = CpuNpuOffloadingHandler(
                 attn_backends=attn_backends,
                 gpu_block_size=self.gpu_block_size,
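
A note on the npu.py change above: get_handlers no longer derives the per-layer attention backends itself; the caller now supplies attn_backends alongside kv_caches. The deleted lookup therefore moves to the caller side, which is not part of this diff; a sketch of that mapping, built with the same helpers the removed code used (the function name is illustrative):

import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase


def build_attn_backends(vllm_config: VllmConfig,
                        kv_caches: dict[str, torch.Tensor]):
    # Same per-layer lookup that NPUOffloadingSpec.get_handlers used to
    # perform internally before this change.
    layer_names = list(kv_caches.keys())
    layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase,
                                         layer_names)
    return {name: layers[name].get_attn_backend() for name in layer_names}

A caller would then pass the result to spec.get_handlers(kv_caches, attn_backends) to obtain the CPU/NPU load-store handlers.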