Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 73 additions & 49 deletions tests/unit_tests/kv_offload/offloading_connector/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any
Expand All @@ -11,31 +10,39 @@

from tests.unit_tests.kv_offload.utils import (
EOS_TOKEN_ID,
create_request_compatible_with_signature,
create_model_runner_output,
create_request_compatible_with_signature,
create_vllm_config,
)
from vllm import SamplingParams
from vllm.config import KVTransferConfig, VllmConfig, set_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import (
OffloadingConnectorMetadata, )
OffloadingConnectorMetadata,
OffloadingWorkerMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
OffloadingConnector, )
from vllm.forward_context import ForwardContext
from vllm.utils.hashing import sha256
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.core.kv_cache_utils import (
BlockHash,
get_request_block_hasher,
init_none_hash,
)
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
)
from vllm.v1.kv_offload.abstract import (
LoadStoreSpec,
OffloadingManager,
OffloadKey,
PrepareStoreOutput,
make_offload_key,
)
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
Expand All @@ -44,26 +51,28 @@
TransferResult,
TransferSpec,
)
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheGroupSpec,
)
from vllm.v1.structured_output import StructuredOutputManager


def to_key(int_hash: int) -> OffloadKey:
    """Build an ``OffloadKey`` from an integer hash (test helper).

    The integer is rendered as its decimal string, encoded to bytes, and
    handed to ``make_offload_key`` together with ``0`` as the second
    argument.
    """
    hash_bytes = str(int_hash).encode()
    return make_offload_key(hash_bytes, 0)


def to_keys(int_hashes: list[int]) -> list[OffloadKey]:
    """Convert every integer hash to an ``OffloadKey`` via ``to_key``, keeping order."""
    return list(map(to_key, int_hashes))


class MockLoadStoreSpec(LoadStoreSpec):
    """Test double for ``LoadStoreSpec`` that only records offload keys.

    The keys are kept on ``offload_keys`` so tests can inspect which keys a
    load or store was prepared for.
    """

    def __init__(self, offload_keys: Iterable[OffloadKey]):
        # Materialize into a list: callers may pass a one-shot iterable, and
        # the keys are read again later (e.g. by ``__repr__``).
        self.offload_keys: list[OffloadKey] = list(offload_keys)

    @staticmethod
    def medium() -> str:
        """Return the medium name of this mock spec (always ``"Mock"``)."""
        return "Mock"

    def __repr__(self) -> str:
        # Single source of truth: the same attribute that ``__init__`` sets.
        return repr(self.offload_keys)


class MockOffloadingHandler(OffloadingHandler):
Expand Down Expand Up @@ -111,7 +120,8 @@ def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig):

self.manager = MagicMock(spec=OffloadingManager)
self.manager.lookup.return_value = 0
Copy link

Copilot AI Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MockOffloadingSpec sets manager.lookup.return_value twice (first to 0, then to False). This looks accidental and makes it unclear what the mocked contract is supposed to be. Please remove the redundant assignment and keep a single return type/value consistent with OffloadingManager.lookup's expected behavior.

Suggested change
self.manager.lookup.return_value = 0

Copilot uses AI. Check for mistakes.
self.manager.prepare_load = lambda block_hashes, req_context: (MockLoadStoreSpec(block_hashes))
self.manager.prepare_load = lambda keys, req_context: MockLoadStoreSpec(keys)
self.manager.lookup.return_value = False
self.handler = MockOffloadingHandler()

def get_manager(self) -> OffloadingManager:
Expand Down Expand Up @@ -143,14 +153,17 @@ class TransferSummary:

class RequestRunner:

def __init__(self,
offloaded_block_size: int,
gpu_block_size: int,
num_gpu_blocks: int,
async_scheduling: bool = False):
def __init__(
self,
offloaded_block_size: int,
gpu_block_size: int,
num_gpu_blocks: int,
async_scheduling: bool = True,
):
self.offloaded_block_size: int = offloaded_block_size
self.gpu_block_size: int = gpu_block_size
self.num_gpu_blocks: int = num_gpu_blocks
self.async_scheduling: bool = async_scheduling

self.req_id: int = -1

Expand Down Expand Up @@ -184,7 +197,8 @@ def __init__(self,
)
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
self.num_kv_groups = len(kv_cache_config.kv_cache_groups)
scheduler_cls = (AsyncScheduler if vllm_config.scheduler_config.async_scheduling else Scheduler)

scheduler_cls = AsyncScheduler if async_scheduling else Scheduler
self.scheduler = scheduler_cls(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
Expand Down Expand Up @@ -247,7 +261,11 @@ def __init__(self,
slot_mapping={},
)

def new_request(self, token_ids: list[int]):
def new_request(
self,
token_ids: list[int],
kv_transfer_params: dict | None = None,
):
self.req_id += 1

sampling_params = SamplingParams(max_tokens=1000)
Expand All @@ -260,9 +278,9 @@ def new_request(self, token_ids: list[int]):
"pooling_params": None,
"block_hasher": self._block_hasher,
}

req = create_request_compatible_with_signature(**request_kwargs)

if kv_transfer_params is not None:
req.kv_transfer_params = kv_transfer_params
self.scheduler.add_request(req)

def _parse_transfers(self):
Expand Down Expand Up @@ -294,11 +312,11 @@ def _parse_transfers(self):
for block_id in gpu_spec.block_ids:
gpu_block_indices.append(self.gpu_block_index[block_id.item()])

# list of (block_hash, sub_block_offset)
# list of (offload_key, sub_block_offset)
offload_addresses: list[Any] = []
for block_hash in offload_spec.block_hashes:
for offload_key in offload_spec.offload_keys:
for sub_block_idx in range(block_size_factor):
offload_addresses.append((block_hash, sub_block_idx))
offload_addresses.append((offload_key, sub_block_idx))

if store:
assert len(gpu_block_indices) == len(offload_addresses)
Expand Down Expand Up @@ -329,8 +347,15 @@ def _run(self, decoded_tokens: list[int], complete_transfers: bool):

tokens_iter = iter(decoded_tokens)
token_id = next(tokens_iter, None)
prev_scheduler_output = None
prev_model_runner_output = None
while True:
assert self.scheduler.requests
# Strict-always-False frees the request immediately on EOS, but
# the worker may still have a deferred store queued. In production
# the next request's step drains it; in single-request tests we
# must keep stepping until the scheduler sees no in-flight jobs.
if not self.scheduler.requests and not self.connector_scheduler._jobs:
break

scheduler_output = self.scheduler.schedule()
self._update_gpu_block_idx()
Expand All @@ -351,6 +376,7 @@ def _run(self, decoded_tokens: list[int], complete_transfers: bool):
self.offloading_spec.complete_transfers()

finished_sending, finished_recving = self.worker_connector.get_finished(scheduler_output.finished_req_ids)
worker_meta = self.worker_connector.build_connector_worker_meta() or OffloadingWorkerMetadata()

self.worker_connector.clear_connector_metadata()

Expand All @@ -359,40 +385,38 @@ def _run(self, decoded_tokens: list[int], complete_transfers: bool):
finished_sending=finished_sending,
finished_recving=finished_recving,
token_id=token_id or 0,
kv_connector_worker_meta=worker_meta,
)

prev_token_id = token_id
if self.scheduler.running:
token_id = next(tokens_iter, None)

self.scheduler.update_from_output(scheduler_output, model_runner_output)
if self.async_scheduling:
# in async scheduling we update the output of the previous step
if prev_model_runner_output is not None:
self.scheduler.update_from_output(prev_scheduler_output, prev_model_runner_output)
prev_scheduler_output = scheduler_output
prev_model_runner_output = model_runner_output
else:
self.scheduler.update_from_output(scheduler_output, model_runner_output)

if (prev_token_id is EOS_TOKEN_ID and prev_token_id != token_id and self.scheduler.requests):
if (prev_token_id == EOS_TOKEN_ID and prev_token_id != token_id
and (self.scheduler.requests or self.connector_scheduler._jobs)):
# continue for one more step to allow offloading to kick off
continue

if token_id is None:
if self.async_scheduling:
# sample last token
self.scheduler.update_from_output(prev_scheduler_output, prev_model_runner_output)
break

self._parse_transfers()

# run one more step to update finished stored
if EOS_TOKEN_ID in decoded_tokens:
assert not self.scheduler.running

while self.scheduler.requests:
scheduler_output = self.scheduler.schedule()

finished_sending, finished_recving = self.worker_connector.get_finished(
scheduler_output.finished_req_ids)

assert not finished_recving

model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(finished_sending=finished_sending)

self.scheduler.update_from_output(scheduler_output, model_runner_output)

def run(
self,
decoded_tokens: list[int],
Expand Down Expand Up @@ -445,7 +469,7 @@ def run(
def request_runner():
runners = []

def runner_factory(offloaded_block_size, gpu_block_size, num_gpu_blocks, async_scheduling=False):
def runner_factory(offloaded_block_size, gpu_block_size, num_gpu_blocks, async_scheduling):
runner = RequestRunner(
offloaded_block_size=offloaded_block_size,
gpu_block_size=gpu_block_size,
Expand All @@ -458,10 +482,10 @@ def runner_factory(offloaded_block_size, gpu_block_size, num_gpu_blocks, async_s
yield runner_factory # pass factory to the test


def generate_store_output(keys: Iterable[OffloadKey]):
    """Build a ``PrepareStoreOutput`` that stores all of ``keys`` and evicts none.

    Args:
        keys: Offload keys to store; any iterable, consumed exactly once.

    Returns:
        A ``PrepareStoreOutput`` whose ``keys_to_store`` is a list of the
        given keys, whose ``store_spec`` is a ``MockLoadStoreSpec`` over the
        same keys, and whose ``evicted_keys`` is empty.
    """
    # Materialize first so a one-shot iterable can feed both fields below.
    keys = list(keys)
    return PrepareStoreOutput(
        # Separate copy so the output's list is not shared with the local one.
        keys_to_store=list(keys),
        store_spec=MockLoadStoreSpec(keys),
        evicted_keys=[],
    )
23 changes: 13 additions & 10 deletions tests/unit_tests/kv_offload/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
KVConnectorWorkerMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( # noqa
ExampleConnector, )
Expand Down Expand Up @@ -64,9 +65,9 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert len(scheduler.encoder_cache_manager.cached) == 0

# KVCache Manager.
assert (len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks) == 0)
assert (len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].num_cached_block) == 0)
num_free_blocks = (scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].num_cached_block) == 0
num_free_blocks = scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks
assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)

# NOTE(rob): just the ref count on blocks will be 0. The hash
Expand Down Expand Up @@ -232,6 +233,7 @@ def create_model_runner_output(
invalid_block_ids: set[int] | None = None,
use_eos: bool = False,
token_id: int = 0,
kv_connector_worker_meta: KVConnectorWorkerMetadata | None = None,
) -> ModelRunnerOutput:
"""Make dummy model runner output for testing."""

Expand All @@ -243,12 +245,13 @@ def create_model_runner_output(
sampled_token = EOS_TOKEN_ID if use_eos else token_id
sampled_token_ids = [[sampled_token] for _ in req_ids]

kv_connector_output = (None if (finished_sending is None and finished_recving is None and invalid_block_ids is None)
else KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
invalid_block_ids=invalid_block_ids or set(),
))
kv_connector_output = (None if (finished_sending is None and finished_recving is None and invalid_block_ids is None
and kv_connector_worker_meta is None) else KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
invalid_block_ids=invalid_block_ids or set(),
kv_connector_worker_meta=kv_connector_worker_meta,
))

# Make output data structure.
return ModelRunnerOutput(
Expand All @@ -269,7 +272,7 @@ def __init__(self, config: VllmConfig, role, kv_cache_config):
self._connector = ExampleConnector(config, role)
self.call_record: dict[str, int] = defaultdict(int)
# Use a unique temp file per connector
self._event_file = (tempfile.gettempdir() + f"/connector_{self.name}-{self.role.name}_events.log")
self._event_file = tempfile.gettempdir() + f"/connector_{self.name}-{self.role.name}_events.log"
# Start with an empty file
with open(self._event_file, "w") as _:
pass
Expand Down
13 changes: 8 additions & 5 deletions tests/unit_tests/ops/test_hpu_compressed_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
HPUCompressedTensorsW8A8Int8_BF16Fallback,
HPUCompressedTensorsW8A8Fp8MoEMethod)
from vllm_gaudi.utils import HPUCompileConfig
from vllm.forward_context import override_forward_context
from vllm.forward_context import ForwardContext, override_forward_context
from safetensors import safe_open


Expand Down Expand Up @@ -387,10 +387,13 @@ def test_compressed_tensors_wna16_moe_method(default_vllm_config: None, dist_ini
ref_output = f.get_tensor("ref_output")

# Execute layer
mock_ctx = MagicMock(spec=["dp_metadata"])
mock_ctx.dp_metadata = None
with override_forward_context(mock_ctx):
out = oot_op.runner._forward_dispatch(oot_op, hidden_states, router_logits, hidden_states)
ctx = ForwardContext(
no_compile_layers={oot_op.runner.layer_name: oot_op},
attn_metadata={},
slot_mapping={},
)
with override_forward_context(ctx):
out = oot_op.runner.forward(hidden_states, router_logits)

# Check correctness
torch.testing.assert_close(ref_output, out, atol=1e-4, rtol=1e-4)
Expand Down
14 changes: 8 additions & 6 deletions tests/unit_tests/ops/test_hpu_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import torch
import habana_frameworks.torch as htorch
from utils import get_data_path, create_fused_moe
from unittest.mock import MagicMock
from vllm_gaudi.ops.hpu_fused_moe import HPUUnquantizedFusedMoEMethod
from vllm_gaudi.utils import HPUCompileConfig
from vllm.forward_context import override_forward_context
from vllm.forward_context import ForwardContext, override_forward_context
from safetensors import safe_open


Expand Down Expand Up @@ -38,10 +37,13 @@ def test_unquantized_fused_moe_method(default_vllm_config: None, dist_init):
ref_output = f.get_tensor("ref_output")

# Execute layer
mock_ctx = MagicMock(spec=["dp_metadata"])
mock_ctx.dp_metadata = None
with override_forward_context(mock_ctx):
out = oot_op.runner._forward_dispatch(oot_op, hidden_states, router_logits, hidden_states)
ctx = ForwardContext(
no_compile_layers={oot_op.runner.layer_name: oot_op},
attn_metadata={},
slot_mapping={},
)
with override_forward_context(ctx):
out = oot_op.runner.forward(hidden_states, router_logits)

# Check correctness
torch.testing.assert_close(ref_output, out, atol=1e-4, rtol=1e-4)
Loading
Loading