Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/kv_connectors.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
lmcache
nixl >= 0.5.1 # Required for disaggregated prefill
nixl >= 0.6.0 # Required for disaggregated prefill
65 changes: 54 additions & 11 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,26 @@ def clear_kv_transfer():
ensure_kv_transfer_shutdown()


def get_default_xfer_telemetry(xferDurationS: float = 1,
                               postDurationS: float = 1,
                               totalBytes: int = 1,
                               descCount: int = 1) -> dict:
    """Build a stand-in for a NIXL ``nixlXferTelemetry`` record.

    The real class is read-only and the ray test env does not have NIXL,
    so we fake it with a dict whose entries are also reachable as
    attributes (``res.xferDuration`` etc.), matching how the connector
    reads telemetry fields.

    Durations are passed in seconds and stored in microseconds, the unit
    NIXL reports.
    """

    class AttributeDict(dict):
        # Expose dict entries through attribute access as well.
        __slots__ = ()

        def __getattr__(self, name):
            return self[name]

        def __setattr__(self, name, value):
            self[name] = value

    telemetry = AttributeDict()
    telemetry.xferDuration = xferDurationS * 1e6  # in us
    telemetry.postDuration = postDurationS * 1e6  # in us
    telemetry.totalBytes = totalBytes
    telemetry.descCount = descCount
    return telemetry


class FakeNixlWrapper:
"""Mock implementation of NixlWrapper for testing.

Expand Down Expand Up @@ -132,6 +152,9 @@ def make_prepped_xfer(self,
def transfer(self, handle: int) -> str:
    """Fake kicking off transfer *handle*: always report "PROC" (in progress)."""
    state = "PROC"
    return state

def get_xfer_telemetry(self, handle: int) -> dict:
    """Return canned telemetry regardless of *handle*, mimicking NIXL's API."""
    return get_default_xfer_telemetry()

############################################################
# Follow are for changing the behavior during testing.
############################################################
Expand Down Expand Up @@ -169,6 +192,11 @@ def _make_fake_nixl_pkg():
with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
f.write(stub)

# Mock nixlXferTelemetry class
pkg_root2 = os.path.join(td, "nixl", "_bindings")
os.makedirs(pkg_root2, exist_ok=True)
with open(os.path.join(pkg_root2, "__init__.py"), "w") as f:
f.write("class nixlXferTelemetry: pass")
# touch parent package
open(os.path.join(td, "nixl", "__init__.py"), "w").close()
yield td
Expand Down Expand Up @@ -575,7 +603,7 @@ def test_kv_connector_stats(dist_init):

# Verify stats values are recorded
assert not stats_after_transfer.is_empty()
assert stats_after_transfer.data["num_successful_transfers"] == 1
assert stats_after_transfer.num_successful_transfers == 1

# Verify stats are reset after retrieval
stats_after_reset = connector.get_kv_connector_stats()
Expand All @@ -599,16 +627,21 @@ def test_kv_connector_stats_aggregation():

# Record different transfers on each worker
# Worker 1: 2 transfers
worker1_stats.record_transfer()
worker1_stats.record_transfer()
stats = get_default_xfer_telemetry()
worker1_stats.record_transfer(stats)
worker1_stats.record_transfer(stats)

# Worker 2: 1 transfer
worker2_stats.record_transfer()
worker2_stats.record_transfer(stats)

# Worker 3: 3 transfers
worker3_stats.record_transfer()
worker3_stats.record_transfer()
worker3_stats.record_transfer()
stats = get_default_xfer_telemetry(xferDurationS=2,
postDurationS=2,
totalBytes=2,
descCount=2)
worker3_stats.record_transfer(stats)
worker3_stats.record_transfer(stats)
worker3_stats.record_transfer(stats)

# Create ModelRunnerOutput instances for each worker
worker_outputs = []
Expand Down Expand Up @@ -636,7 +669,12 @@ def test_kv_connector_stats_aggregation():
aggregated_output.kv_connector_output.kv_connector_stats
assert isinstance(kv_connector_stats, NixlKVConnectorStats)
# Number of total transfers across all workers.
assert kv_connector_stats.data["num_successful_transfers"] == 6
assert kv_connector_stats.num_successful_transfers == 6
# Logging proc, call reduce() to get CLI-friendly stats.
cli_stats = kv_connector_stats.reduce()
assert cli_stats["Avg xfer time (ms)"] == 1500.0
assert cli_stats["Avg post time (ms)"] == 1500.0
assert cli_stats["Avg number of descriptors"] == 1.5


def test_multi_kv_connector_stats_aggregation():
Expand All @@ -649,6 +687,7 @@ def test_multi_kv_connector_stats_aggregation():

from dataclasses import dataclass

# Mock a KVConnectorStats class for testing aggregation over connectors.
@dataclass
class FooKVConnectorStats(KVConnectorStats):

Expand Down Expand Up @@ -676,7 +715,7 @@ def make_multi_stats(nixl_count: int,
if nixl_count > 0:
nixl_stats = NixlKVConnectorStats()
for _ in range(nixl_count):
nixl_stats.record_transfer()
nixl_stats.record_transfer(get_default_xfer_telemetry())
data["NixlConnector"] = nixl_stats
if foo_count > 0:
foo_stats = FooKVConnectorStats()
Expand Down Expand Up @@ -712,8 +751,10 @@ def make_multi_stats(nixl_count: int,
assert isinstance(kv_connector_stats, MultiKVConnectorStats)

# Validate per-connector totals across workers
assert kv_connector_stats["NixlConnector"].data[
"num_successful_transfers"] == 5
assert isinstance(kv_connector_stats["NixlConnector"],
NixlKVConnectorStats)
assert kv_connector_stats["NixlConnector"].num_successful_transfers == 5
assert isinstance(kv_connector_stats["FooConnector"], FooKVConnectorStats)
assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6


Expand Down Expand Up @@ -755,6 +796,8 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
"working_dir": working_dir, # ship fake nixl package
"env_vars": {
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
# TODO: needed so ray carries the env var over to workers; remove
# once NIXL telemetry can be enabled via config.
"NIXL_TELEMETRY_ENABLE": "1",
},
}
ray.init(runtime_env=runtime_env)
Expand Down
82 changes: 69 additions & 13 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import copy
import logging
import math
import os
import queue
import threading
import time
Expand Down Expand Up @@ -54,10 +55,12 @@
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try:
from nixl._api import nixl_agent as NixlWrapper
from nixl._bindings import nixlXferTelemetry
logger.info("NIXL is available")
except ImportError:
logger.warning("NIXL is not available")
NixlWrapper = None
nixlXferTelemetry = None

try:
from nixl._api import nixl_agent_config
Expand Down Expand Up @@ -476,6 +479,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.nixl_backends = \
vllm_config.kv_transfer_config.get_from_extra_config(
"backends", ["UCX"])
# TODO: temporary — once nixl supports a telemetry flag in its
# config (next release), this env var can be removed.
os.environ["NIXL_TELEMETRY_ENABLE"] = "1"
# Agent.
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
if nixl_agent_config is None:
Expand Down Expand Up @@ -1175,9 +1181,10 @@ def _pop_done_transfers(
for handle, _xfer_stime in handles:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
# Get telemetry from NIXL
res = self.nixl_wrapper.get_xfer_telemetry(handle)
self.xfer_stats.record_transfer(res)
self.nixl_wrapper.release_xfer_handle(handle)
# TODO (NickLucche) Get from NIXL telemetry once integrated
self.xfer_stats.record_transfer()
elif xfer_state == "PROC":
in_progress = True
continue
Expand Down Expand Up @@ -1449,32 +1456,81 @@ class NixlKVConnectorStats(KVConnectorStats):
"""Container for transfer performance metrics"""

def __post_init__(self):
if "num_successful_transfers" not in self.data:
self.data["num_successful_transfers"] = 0
if not self.data:
# Empty container init, no data is passed in.
self.reset()

def reset(self):
self.data = {"num_successful_transfers": 0}
# Must be serializable
self.data: dict[str, list[float]] = {
"transfer_duration": [],
"post_duration": [],
"bytes_transferred": [],
"num_descriptors": [],
}

def record_transfer(self):
# TODO: record actual transfer stats when available
self.data["num_successful_transfers"] += 1
def record_transfer(self, res: nixlXferTelemetry):
# Keep metrics units consistent with rest of the code: time us->s
self.data["transfer_duration"].append(res.xferDuration / 1e6)
self.data["post_duration"].append(res.postDuration / 1e6)
self.data["bytes_transferred"].append(res.totalBytes)
self.data["num_descriptors"].append(res.descCount)

def clone_and_reset(self) -> "NixlKVConnectorStats":
    """Snapshot the accumulated stats and start a fresh window.

    Returns a shallow copy holding everything recorded so far; this
    instance is then cleared via ``self.reset()``.
    """
    snapshot = copy.copy(self)
    self.reset()
    return snapshot

def is_empty(self) -> bool:
return self.data["num_successful_transfers"] == 0
return self.num_successful_transfers == 0

def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
if not other.is_empty():
self.data["num_successful_transfers"] += other.data[
"num_successful_transfers"]
for k, v in other.data.items():
accumulator = self.data[k]
assert isinstance(accumulator, list)
accumulator.extend(v)
return self

def reduce(self) -> dict[str, Union[int, float]]:
# TODO: reduce stats to a single value, calculate latency/throughput
# Compute compact representative stats suitable for CLI logging
if self.is_empty():
return {
"Num successful transfers": 0,
"Avg xfer time (ms)": 0,
"P90 xfer time (ms)": 0,
"Avg post time (ms)": 0,
"P90 post time (ms)": 0,
"Avg MB per transfer": 0,
"Throughput (MB/s)": 0,
"Avg number of descriptors": 0,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This string template thingy is repeated twice in the same function. Very minor nit

}

xfer_time = np.asarray(self.data["transfer_duration"])
post_time = np.asarray(self.data["post_duration"])
# Convert to MB for CLI logging.
mb = np.asarray(self.data["bytes_transferred"]) / 2**20
descs = np.asarray(self.data["num_descriptors"], dtype=np.uint32)
n = len(descs)
assert n == self.num_successful_transfers

total_mb = mb.sum()
avg_mb = total_mb / n

total_time_seconds = xfer_time.sum()
throughput_mb_s = total_mb / total_time_seconds

return {
"num_successful_transfers": self.data["num_successful_transfers"]
"Num successful transfers": n,
"Avg xfer time (ms)": round(xfer_time.mean() * 1e3, 3),
"P90 xfer time (ms)": round(np.percentile(xfer_time, 90) * 1e3, 3),
"Avg post time (ms)": round(post_time.mean() * 1e3, 3),
"P90 post time (ms)": round(np.percentile(post_time, 90) * 1e3, 3),
"Avg MB per transfer": round(avg_mb, 3),
"Throughput (MB/s)": round(throughput_mb_s, 3),
"Avg number of descriptors": round(descs.mean(), 1),
}

@property
def num_successful_transfers(self) -> int:
    """Number of transfers recorded since the last reset.

    Each recorded transfer appends one entry per metric list in
    ``self.data``; the ``transfer_duration`` list is used as the
    canonical counter.
    """
    durations = self.data["transfer_duration"]
    return len(durations)
6 changes: 3 additions & 3 deletions vllm/v1/metrics/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
self.prefix_caching_metrics = PrefixCachingMetrics()
self.spec_decoding_logging = SpecDecodingLogging()
kv_tranfer_config = self.vllm_config.kv_transfer_config
self.kv_transfer_logging = KVConnectorLogging(kv_tranfer_config)
self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config)
self.last_prompt_throughput: float = 0.0
self.last_generation_throughput: float = 0.0

Expand Down Expand Up @@ -101,7 +101,7 @@ def record(self,
self.spec_decoding_logging.observe(
scheduler_stats.spec_decoding_stats)
if kv_connector_stats := scheduler_stats.kv_connector_stats:
self.kv_transfer_logging.observe(kv_connector_stats)
self.kv_connector_logging.observe(kv_connector_stats)
self.last_scheduler_stats = scheduler_stats

def log(self):
Expand Down Expand Up @@ -140,7 +140,7 @@ def log(self):
self.prefix_caching_metrics.hit_rate * 100,
)
self.spec_decoding_logging.log(log_fn=log_fn)
self.kv_transfer_logging.log(log_fn=log_fn)
self.kv_connector_logging.log(log_fn=log_fn)

def log_engine_initialized(self):
if self.vllm_config.cache_config.num_gpu_blocks:
Expand Down