Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

import torch

from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVTransferAggregatedStats

Check failure on line 39 in vllm/distributed/kv_transfer/kv_connector/v1/base.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/distributed/kv_transfer/kv_connector/v1/base.py:39:81: E501 Line too long (90 > 80)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput

Expand Down Expand Up @@ -79,6 +80,12 @@
def role(self) -> KVConnectorRole:
return self._role

def get_transfer_stats(self) -> Optional[KVTransferAggregatedStats]:
"""
Get the transfer stats.
"""
return None

# ==============================
# Worker-side methods
# ==============================
Expand Down Expand Up @@ -187,7 +194,7 @@

def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
) -> tuple[Optional[set[str]], Optional[set[str]], Optional[KVTransferAggregatedStats]]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.
Expand All @@ -199,7 +206,7 @@
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
return None, None
return None, None, None

# ==============================
# Scheduler-side methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVTransferAggregatedStats

Check failure on line 11 in vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py:11:81: E501 Line too long (90 > 80)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput

Expand Down Expand Up @@ -89,7 +90,7 @@

def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
) -> tuple[Optional[set[str]], Optional[set[str]], Optional[KVTransferAggregatedStats]]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.
Expand All @@ -101,7 +102,7 @@
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
return self._lmcache_engine.get_finished(finished_req_ids)
return self._lmcache_engine.get_finished(finished_req_ids), None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The return value of this function does not match the updated type hint. self._lmcache_engine.get_finished(finished_req_ids) returns a tuple of two elements, so this function is currently returning ((elem1, elem2), None), which will cause a TypeError at runtime. The new signature is tuple[Optional[set[str]], Optional[set[str]], Optional[KVTransferAggregatedStats]].

You should unpack the tuple returned by _lmcache_engine.get_finished and add None as the third element for the stats.

Suggested change
return self._lmcache_engine.get_finished(finished_req_ids), None
return *self._lmcache_engine.get_finished(finished_req_ids), None


# ==============================
# Scheduler-side methods
Expand Down
131 changes: 131 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from dataclasses import dataclass, field

from vllm.logger import init_logger

logger = init_logger(__name__)

@dataclass
class KVTransferAggregatedStats:
    """Container for aggregating performance metrics across engines.

    Averages are running means weighted by ``num_successful_transfers`` so
    that folding in partial results from several workers stays unbiased.
    """
    avg_transfer_durations: float = 0.0  # mean transfer duration (seconds)
    avg_bytes_transferred: float = 0.0  # mean bytes moved per transfer
    num_blocks_transferred: int = 0  # total KV blocks moved
    num_successful_transfers: int = 0  # total completed transfers

    def aggregate(self, other: "KVTransferAggregatedStats"):
        """Fold ``other``'s stats into this instance, in place.

        Empty stats are ignored so that a worker with no completed
        transfers does not skew the averages.
        """
        if other.is_empty():
            return

        # Reduce stats: totals are summed, means are combined as a
        # weighted average proportional to each side's transfer count.
        total = self.num_successful_transfers + other.num_successful_transfers
        self.avg_transfer_durations = (
            self.avg_transfer_durations * self.num_successful_transfers +
            other.avg_transfer_durations * other.num_successful_transfers
        ) / total
        self.avg_bytes_transferred = (
            self.avg_bytes_transferred * self.num_successful_transfers +
            other.avg_bytes_transferred * other.num_successful_transfers
        ) / total
        self.num_blocks_transferred += other.num_blocks_transferred
        self.num_successful_transfers = total

    def is_empty(self) -> bool:
        """True when no successful transfer has been recorded yet."""
        return self.num_successful_transfers == 0

@dataclass
class KVTransferStats:
    """Container for per-transfer performance measurements.

    Raw measurements are appended by :meth:`record_transfer` and drained
    into a :class:`KVTransferAggregatedStats` by :meth:`reduce_and_reset`.
    """
    transfer_durations: list[float] = field(
        default_factory=list)  # transfer durations in seconds
    bytes_transferred: list[int] = field(
        default_factory=list)  # bytes transferred per transfer
    num_blocks_transferred: list[int] = field(
        default_factory=list)  # number of blocks per transfer
    num_transfers: int = 0  # completed transfers since last reduce

    def reset(self):
        """Discard all recorded measurements (num_transfers is kept)."""
        self.transfer_durations = []
        self.bytes_transferred = []
        self.num_blocks_transferred = []

    def observe(self):
        """Count one completed transfer.

        NOTE(review): callers currently invoke this without measurements,
        so only the transfer count is tracked here; detailed stats come
        from record_transfer().
        """
        self.num_transfers += 1

    def reduce_and_reset(self) -> "KVTransferAggregatedStats":
        """Aggregate collected measurements and reset the collectors.

        Averages are computed from the recorded lists; when only
        observe() was used (no detailed measurements) they fall back to
        0.0 while the transfer count is still reported.
        """
        # NOTE (NickLucche): to have statistical significance, we assume the
        # size of the measurements groups to be the same. This allows to bound
        # the size of the messages.
        n = len(self.transfer_durations)
        stats = KVTransferAggregatedStats(
            avg_transfer_durations=(sum(self.transfer_durations) /
                                    n if n else 0.0),
            avg_bytes_transferred=(sum(self.bytes_transferred) /
                                   n if n else 0.0),
            num_blocks_transferred=sum(self.num_blocks_transferred),
            num_successful_transfers=self.num_transfers)
        # Clear both the counter and the lists so measurements are not
        # double-counted (and do not grow without bound) across intervals.
        self.num_transfers = 0
        self.reset()
        return stats

    def record_transfer(self, duration: float, bytes_count: int,
                        num_blocks: int):
        """Record the measurements of one completed transfer."""
        self.transfer_durations.append(duration)
        self.bytes_transferred.append(bytes_count)
        self.num_blocks_transferred.append(num_blocks)

    def get_throughput_stats(self, now: float) -> tuple[float, float, float]:
        """Get transfer throughput statistics.

        Returns (bytes/s, blocks/s, transfers/s) computed against the
        cumulative recorded transfer time. ``now`` is kept for interface
        compatibility but is unused. Returns zeros when nothing has been
        recorded yet.
        """
        total_duration = sum(self.transfer_durations)
        if total_duration <= 0.0:
            return 0.0, 0.0, 0.0
        bytes_per_sec = sum(self.bytes_transferred) / total_duration
        blocks_per_sec = sum(self.num_blocks_transferred) / total_duration
        transfers_per_sec = len(self.transfer_durations) / total_duration
        return bytes_per_sec, blocks_per_sec, transfers_per_sec

    def get_latency_stats(self) -> tuple[float, float, float]:
        """Get (avg, p50, p95) transfer latency in seconds.

        Returns zeros when no transfers have been recorded.
        """
        # Local import keeps numpy optional for users that never log stats.
        import numpy as np
        if not self.transfer_durations:
            return 0.0, 0.0, 0.0
        durations = np.array(self.transfer_durations)
        avg_latency = float(np.mean(durations))
        p50_latency = float(np.percentile(durations, 50))
        p95_latency = float(np.percentile(durations, 95))

        return avg_latency, p50_latency, p95_latency

class KVTransferLogging:
    """Holds the most recent aggregated KV-transfer stats and logs them
    periodically, similar to throughput logging.

    TODO: also report throughput/latency breakdowns (bytes/s, blocks/s,
    p50/p95 latency) once per-interval measurements are plumbed through.
    """

    def __init__(self):
        self.reset()
        # Most recent aggregated stats snapshot; None until first observe().
        self.transfer_stats = None

    def reset(self):
        """Clear accumulated per-interval data."""
        self.transfer_durations = []
        self.bytes_transferred = []
        self.num_blocks_transferred: int = 0

    def observe(self, transfer_stats: "KVTransferAggregatedStats"):
        """Store the latest aggregated stats snapshot to be logged."""
        self.transfer_stats = transfer_stats

    def log(self, log_fn=None):
        """Log transfer metrics periodically, similar to throughput logging.

        `log_fn` defaults to `logger.info`; it is resolved lazily at call
        time so the module logger can be swapped out (e.g. in tests).
        """
        if log_fn is None:
            log_fn = logger.info
        # Only log if we have transfer data.
        if self.transfer_stats is not None:
            log_fn("KV Transfer metrics: %s", self.transfer_stats)
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVTransferAggregatedStats

Check failure on line 14 in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py:14:81: E501 Line too long (90 > 80)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
Expand Down Expand Up @@ -104,11 +105,14 @@

def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
) -> tuple[Optional[set[str]], Optional[set[str]], Optional[KVTransferAggregatedStats]]:
finished_sending: set[str] = set()
finished_recving: set[str] = set()
base_xfer_stats = KVTransferAggregatedStats()
for c in self._connectors:
sending, recving = c.get_finished(finished_req_ids)
sending, recving, xfer_stats = c.get_finished(finished_req_ids)
if xfer_stats is not None:
base_xfer_stats.aggregate(xfer_stats)
if not recving and not sending:
continue
# Aggregate finished recving request ids.
Expand All @@ -127,9 +131,10 @@
else:
self._extra_async_saves[req_id] = extra_pending - 1

return finished_sending or None, finished_recving or None
base_xfer_stats = None if base_xfer_stats.is_empty() else base_xfer_stats
return finished_sending or None, finished_recving or None, base_xfer_stats

# ==============================

Check failure on line 137 in vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py:137:81: E501 Line too long (82 > 80)
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
Expand Down
63 changes: 50 additions & 13 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVTransferAggregatedStats, KVTransferStats

Check failure on line 24 in vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py:24:81: E501 Line too long (107 > 80)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group)
Expand All @@ -37,7 +38,7 @@
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request

Transfer = tuple[int, float] # (xfer_handle, start_time)
Transfer = tuple[int, float, int] # (xfer_handle, start_time, num_blocks)
EngineId = str
ReqId = str
GET_META_MSG = b"get_meta_msg"
Expand Down Expand Up @@ -179,6 +180,11 @@
"""NixlConnector does not save explicitly."""
pass


def get_transfer_stats(self) -> Optional[KVTransferAggregatedStats]:
assert self.connector_worker is not None
return self.connector_worker.xfer_stats_aggregated


class NixlConnectorScheduler:
"""Implementation of Scheduler side methods"""
Expand Down Expand Up @@ -441,6 +447,10 @@
# With heterogeneous TP, P must wait for all assigned D TP workers to
# finish reading before safely freeing the blocks.
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)

# Transfer metrics tracking.
self.xfer_stats = KVTransferStats()
self.xfer_stats_aggregated = KVTransferAggregatedStats()

def __del__(self):
"""Cleanup background threads on destruction."""
Expand Down Expand Up @@ -825,9 +835,12 @@
"Rank %s, get_finished: %s requests done sending "
"and %s requests done recving", self.tp_rank,
len(done_sending), len(done_recving))


# Aggregate transfer stats for this rank.
xfer_stats = self.xfer_stats.reduce_and_reset()
self.xfer_stats_aggregated.aggregate(xfer_stats)
Comment thread
NickLucche marked this conversation as resolved.
if self.world_size == 1:
return done_sending, done_recving
return done_sending, done_recving, self.xfer_stats_aggregated

# Rank 0: get finished from all other ranks.
if self.tp_rank == 0:
Expand All @@ -839,8 +852,12 @@
# Keep track of how many other ranks have finished.
other_ranks_finished_ids: list[str] = []
for i in range(1, self.world_size):
other_ranks_finished_ids.extend(
self.tp_group.recv_object(src=i))
finished_req_ids, xfer_stats = self.tp_group.recv_object(src=i)
other_ranks_finished_ids.extend(finished_req_ids)
# Aggregate transfer stats from all ranks.
self.xfer_stats_aggregated.aggregate(xfer_stats)
# TODO reset after logging or keep global?

for req_id in other_ranks_finished_ids:
if (req_id in self._done_recving_count
or req_id in self._recving_transfers):
Expand All @@ -860,16 +877,16 @@
if self._done_sending_count[req_id] == self.world_size:
del self._done_sending_count[req_id]
all_done_sending.add(req_id)

Check failure on line 880 in vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py:880:81: E501 Line too long (81 > 80)
return all_done_sending, all_done_recving
return all_done_sending, all_done_recving, self.xfer_stats_aggregated

# Ranks 1 to N-1: send finished ids to Rank 0.
else:
finished_req_ids = list(done_recving.union(done_sending))
self.tp_group.send_object(finished_req_ids, dst=0)
self.tp_group.send_object((finished_req_ids, xfer_stats), dst=0)

# Unused as only Rank 0 results are sent to scheduler.
return done_sending, done_recving
return done_sending, done_recving, self.xfer_stats_aggregated

def _get_new_notifs(self) -> set[str]:
"""
Expand All @@ -890,7 +907,7 @@
return notified_req_ids

def _pop_done_transfers(
self, transfers: dict[str, list[tuple[int, float]]]) -> set[str]:
self, transfers: dict[str, list[Transfer]]) -> set[str]:
"""
Pop completed xfers by checking for DONE state.
Args:
Expand All @@ -899,23 +916,43 @@
set of req_ids that have all done xfers
"""
done_req_ids: set[str] = set()
current_time = time.perf_counter()

for req_id, handles in list(transfers.items()):
in_progress = False
for handle, _xfer_stime in handles:
for handle, xfer_stime, num_blocks in handles:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
# Calculate transfer metrics

Check failure on line 926 in vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (F841)

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py:926:21: F841 Local variable `transfer_duration` is assigned to but never used
transfer_duration = current_time - xfer_stime

# Calculate bytes transferred based on actual block count

Check failure on line 929 in vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (F841)

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py:929:21: F841 Local variable `bytes_transferred` is assigned to but never used
bytes_transferred = self.block_len * num_blocks

# Record the completed transfer metrics
# self.transfer_stats.record_transfer(
# duration=transfer_duration,
# bytes_count=bytes_transferred,
# num_blocks=num_blocks
# )
# TODO actual observe
self.xfer_stats.observe()
Comment thread
NickLucche marked this conversation as resolved.

self.nixl_wrapper.release_xfer_handle(handle)
elif xfer_state == "PROC":
in_progress = True
continue
else:
raise RuntimeError("Transfer failed with state %s",
xfer_state)


if not in_progress:
done_req_ids.add(req_id)
del transfers[req_id]
return done_req_ids


def start_load_kv(self, metadata: NixlConnectorMetadata):
"""
Start loading by triggering non-blocking nixl_xfer.
Expand Down Expand Up @@ -1046,9 +1083,9 @@
self.nixl_wrapper.transfer(handle)

# Use handle to check completion in future step().
# TODO (NickLucche) surface xfer elapsed time
self._recving_transfers[request_id].append(
(handle, time.perf_counter()))
# Store handle, start_time, and block count for metrics tracking
transfer_info = (handle, time.perf_counter(), len(local_block_ids))
self._recving_transfers[request_id].append(transfer_info)

def _get_block_descs_ids(self,
engine_id: str,
Expand Down
Loading