diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py index 080079afccb7..0cfa64b2b5d8 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Any, List, Optional import numpy as np import numpy.typing as npt @@ -40,6 +40,10 @@ class KVArgs: prefill_start_layer: int # for system dp system_dp_rank: int + # Optional tensor buffer references for CPU buffer KV transfer + k_buffers: Optional[List[Any]] = None # List[torch.Tensor], one per layer + v_buffers: Optional[List[Any]] = None # List[torch.Tensor], one per layer + head_dim: Optional[int] = None class KVPoll: diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index b015fd458b74..4571e4962a87 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -410,9 +410,16 @@ def __init__( self.target_tp_ranks = [self.target_tp_rank] elif self.kv_mgr.attn_tp_size > self.prefill_info.attn_tp_size: if not self.kv_mgr.is_mla_backend: - logger.warning_once( - "Performance is NOT guaranteed when using different TP sizes for non-MLA models. " - ) + if getattr(self.kv_mgr, "nixl_use_cpu_buffer", False): + logger.info_once( + "Mixed TP sizes detected (decode_tp > prefill_tp). " + "CPU buffer transfer (--nixl-use-cpu-buffer) is enabled for correct head redistribution." + ) + else: + logger.warning_once( + "Performance is NOT guaranteed when using different TP sizes for non-MLA models. " + "Consider running with --nixl-use-cpu-buffer for correct mixed-TP KV transfer." 
+ ) self.target_tp_rank = ( self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size ) // (self.kv_mgr.attn_tp_size // self.prefill_info.attn_tp_size) @@ -425,9 +432,16 @@ def __init__( self.target_tp_ranks = [self.target_tp_rank] else: if not self.kv_mgr.is_mla_backend: - logger.warning_once( - "Performance is NOT guaranteed when using different TP sizes for non-MLA models. " - ) + if getattr(self.kv_mgr, "nixl_use_cpu_buffer", False): + logger.info_once( + "Mixed TP sizes detected (prefill_tp > decode_tp). " + "CPU buffer transfer (--nixl-use-cpu-buffer) is enabled for correct head redistribution." + ) + else: + logger.warning_once( + "Performance is NOT guaranteed when using different TP sizes for non-MLA models. " + "Consider running with --nixl-use-cpu-buffer for correct mixed-TP KV transfer." + ) # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models; self.target_tp_ranks = [ rank diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 7d8630c5cb66..68af18fba166 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -304,6 +304,15 @@ def _init_kv_manager(self) -> CommonKVManager: kv_args.kv_item_lens = kv_item_lens kv_args.page_size = self.token_to_kv_pool.page_size + # Add tensor buffer references for CPU buffer KV transfer + if getattr(self.scheduler.server_args, "nixl_use_cpu_buffer", False): + if hasattr(self.token_to_kv_pool, "k_buffer") and hasattr( + self.token_to_kv_pool, "v_buffer" + ): + kv_args.k_buffers = self.token_to_kv_pool.k_buffer + kv_args.v_buffers = self.token_to_kv_pool.v_buffer + kv_args.head_dim = self.token_to_kv_pool.head_dim + kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = ( self.metadata_buffers.get_buf_infos() ) diff --git a/python/sglang/srt/disaggregation/nixl/__init__.py b/python/sglang/srt/disaggregation/nixl/__init__.py index 
4df7baba2dfa..db8cbf62c584 100644 --- a/python/sglang/srt/disaggregation/nixl/__init__.py +++ b/python/sglang/srt/disaggregation/nixl/__init__.py @@ -4,3 +4,4 @@ NixlKVReceiver, NixlKVSender, ) +from sglang.srt.disaggregation.nixl.pinned_buffer_pool import PinnedBufferPool diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 279dfab90b73..dc2a44ad92a3 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -2,15 +2,18 @@ import dataclasses import logging +import os import struct import threading import time import uuid from collections import defaultdict -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Tuple import numpy as np import numpy.typing as npt +import requests +import torch from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll from sglang.srt.disaggregation.common.conn import ( @@ -20,14 +23,70 @@ CommonKVSender, ) from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous +from sglang.srt.disaggregation.nixl.pinned_buffer_pool import PinnedBufferPool from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.environ import envs from sglang.srt.server_args import ServerArgs logger = logging.getLogger(__name__) +# Default staging buffer size for Triton KV transfer (256MB) +DEFAULT_TRITON_STAGING_BUFFER_SIZE_MB = 256.0 + + +def _import_triton_kv_transfer(): + """Lazily import Triton KV transfer functions to avoid import errors when not used.""" + try: + from sglang.srt.layers.attention.triton_ops.kv_transfer import ( + gather_kv_to_pinned_all_layers, + scatter_kv_with_staging_all_layers, + ) + + return gather_kv_to_pinned_all_layers, scatter_kv_with_staging_all_layers + except ImportError as e: + logger.warning(f"[TRITON-KV] Failed to import Triton KV transfer: {e}") + return None, None + + GUARD = "NixlMsgGuard".encode("ascii") +# Set 
SGLANG_NIXL_DEBUG_CHECKSUM=1 to enable KV transfer checksum validation. +# Prefill logs a checksum after gather; decode logs one before scatter. +# Mismatches indicate corruption between gather and scatter. +_NIXL_DEBUG_CHECKSUM = os.environ.get("SGLANG_NIXL_DEBUG_CHECKSUM", "0") == "1" + + +def _kv_checksum(buf: torch.Tensor, label: str, room: int) -> int: + """ + Compute and log a diagnostic checksum of a pinned CPU KV buffer. + + Views the buffer as int16, samples ~1024 evenly-spaced elements, and sums + them. Returns the 32-bit wrapped sum. Logs at WARNING level so it is + visible without changing the log level. + + Detects: + - All-zero buffer → transfer never wrote / wrong address + - Value mismatch → data written to wrong slot / premature pool release + - Count of zeros → partially-written transfer + """ + flat = buf.view(torch.int16) + n = flat.numel() + step = max(1, n // 1024) + sample = flat[::step].to(torch.int32) + total = int(sample.sum().item()) + checksum = total & 0xFFFFFFFF + num_zeros = int((sample == 0).sum().item()) + # Log first 4 and last 4 raw int16 values as a sanity peek + head_vals = flat[:4].tolist() + tail_vals = flat[-4:].tolist() + logger.warning( + f"[KV-CKSUM] {label} room={room} " + f"checksum=0x{checksum:08x} " + f"sampled={len(sample)} zeros={num_zeros}/{len(sample)} " + f"head={head_vals} tail={tail_vals}" + ) + return checksum + @dataclasses.dataclass class TransferInfo: @@ -41,6 +100,9 @@ class TransferInfo: dst_aux_index: int required_dst_info_num: int dst_state_indices: List[int] + # Per-request allocated pinned buffer address and size (for concurrent-safe CPU buffer transfers) + dst_pinned_ptr: int = 0 + dst_pinned_size: int = 0 def is_dummy(self): return self.dst_kv_indices.size == 0 @@ -53,6 +115,10 @@ def from_zmq(cls, msg: List[bytes]): else: dst_state_indices = [] + # Parse per-request pinned buffer info from msg[8]/msg[9] if present + dst_pinned_ptr = int(msg[8].decode("ascii")) if len(msg) > 8 else 0 + 
dst_pinned_size = int(msg[9].decode("ascii")) if len(msg) > 9 else 0 + return cls( room=int(msg[0].decode("ascii")), endpoint=msg[1].decode("ascii"), @@ -62,6 +128,8 @@ def from_zmq(cls, msg: List[bytes]): dst_aux_index=int(msg[5].decode("ascii")), required_dst_info_num=int(msg[6].decode("ascii")), dst_state_indices=dst_state_indices, + dst_pinned_ptr=dst_pinned_ptr, + dst_pinned_size=dst_pinned_size, ) @@ -81,6 +149,9 @@ class KVArgsRegisterInfo: decode_tp_size: int decode_tp_rank: int dst_kv_item_len: int + # For Triton KV transfer: pinned CPU buffer address and size + dst_pinned_ptr: int = 0 + dst_pinned_size: int = 0 @classmethod def from_zmq(cls, msg: List[bytes]): @@ -90,6 +161,9 @@ def from_zmq(cls, msg: List[bytes]): else: dst_state_data_ptrs = [] + dst_pinned_ptr = int(msg[12].decode("ascii")) if len(msg) > 12 else 0  # optional frame: fall back to the dataclass default when the peer omits it + dst_pinned_size = int(msg[13].decode("ascii")) if len(msg) > 13 else 0  # optional frame: same fallback as dst_pinned_ptr + return cls( room=str(msg[0].decode("ascii")), endpoint=msg[1].decode("ascii"), @@ -103,6 +177,8 @@ def from_zmq(cls, msg: List[bytes]): decode_tp_size=int(msg[9].decode("ascii")), decode_tp_rank=int(msg[10].decode("ascii")), dst_kv_item_len=int(msg[11].decode("ascii")), + dst_pinned_ptr=dst_pinned_ptr, + dst_pinned_size=dst_pinned_size, ) @@ -110,18 +186,19 @@ def from_zmq(cls, msg: List[bytes]): class TransferStatus: """Used by KV Receiver to know when a transfer is done.""" - # KV chunks received per pp_rank: {pp_rank: set of chunk_ids} - received_kvs_per_pp: Dict[int, Set[int]] = dataclasses.field( + # KV chunks received per sender: {sender_key: set of chunk_ids} + # sender_key is the NIXL peer_name, which uniquely identifies each prefill TP rank + received_kvs_per_sender: Dict[str, Set[int]] = dataclasses.field( default_factory=lambda: defaultdict(set) ) - # Expected chunk count per pp_rank (set when is_last=True): {pp_rank: expected_count} - expected_kvs_per_pp: Dict[int, int] = dataclasses.field(default_factory=dict) - # Number of PP ranks expected to send data. 
- num_pp_ranks_expected: Optional[int] = None + # Expected chunk count per sender (set when is_last=True): {sender_key: expected_count} + expected_kvs_per_sender: Dict[str, int] = dataclasses.field(default_factory=dict) + # Number of senders expected to send data. + num_senders_expected: Optional[int] = None # Whether aux data has been received. received_aux: bool = False - # PP ranks that have sent state data (state is layer-specific, each PP rank sends its portion). - received_state_per_pp: Set[int] = dataclasses.field(default_factory=set) + # Senders that have sent state data. + received_state_per_sender: Set[str] = dataclasses.field(default_factory=set) # Whether state data is expected (set based on state_type). expects_state: bool = False # Mark as failed @@ -130,20 +207,20 @@ class TransferStatus: def is_done(self): if self.is_failure: return True - if self.num_pp_ranks_expected is None or not self.received_aux: + if self.num_senders_expected is None or not self.received_aux: return False - # If state data is expected, check all PP ranks have sent it + # If state data is expected, check all senders have sent it if ( self.expects_state - and len(self.received_state_per_pp) < self.num_pp_ranks_expected + and len(self.received_state_per_sender) < self.num_senders_expected ): return False - # All PP ranks must have reported their expected count - if len(self.expected_kvs_per_pp) < self.num_pp_ranks_expected: + # All senders must have reported their expected count + if len(self.expected_kvs_per_sender) < self.num_senders_expected: return False - # Each PP rank must have received all expected chunks - for pp_rank, expected in self.expected_kvs_per_pp.items(): - if len(self.received_kvs_per_pp[pp_rank]) != expected: + # Each sender must have received all expected chunks + for sender_key, expected in self.expected_kvs_per_sender.items(): + if len(self.received_kvs_per_sender[sender_key]) != expected: return False return True @@ -184,14 +261,34 @@ def __init__( ) 
logger.info(f"NIXL KVManager initialized with backend: {backend}") + # Store CPU buffer transfer configuration + self.nixl_use_cpu_buffer = getattr(server_args, "nixl_use_cpu_buffer", False) + self.triton_staging_buffer: Optional[torch.Tensor] = None + self._pinned_pool: Optional[PinnedBufferPool] = None + self.triton_pinned_descs = None + self._server_args = server_args + + # Initialize Triton transfer infrastructure if enabled + if self.nixl_use_cpu_buffer: + self._init_triton_transfer_buffers() + self.register_buffer_to_engine() if self.disaggregation_mode == DisaggregationMode.PREFILL: + # Per-request KV index accumulation for CPU buffer chunked-prefill fix. + # Maps bootstrap_room -> list of kv_index arrays (one per send() chunk). + # All chunks are combined into a single gather+NIXL write on is_last, + # preventing each chunk from overwriting the previous chunk's offset in + # the destination (decode) CPU buffer. + self._cpu_pending_kv: Dict[int, List] = {} self._start_bootstrap_thread() elif self.disaggregation_mode == DisaggregationMode.DECODE: self.transfer_statuses: Dict[int, TransferStatus] = defaultdict( TransferStatus ) + # Deferred pool releases: (cuda_event, pool_offset) pairs waiting for + # the scatter kernel to finish before the pinned region can be reused. 
+ self._pending_pool_releases: List[Tuple[torch.cuda.Event, int]] = [] self._start_heartbeat_checker_thread() else: raise ValueError( @@ -320,6 +417,49 @@ def register_buffer_to_engine(self): if not self.state_descs: raise Exception("NIXL memory registration failed for state tensors") + # Register shared pinned buffer pool with NIXL if enabled + if self.nixl_use_cpu_buffer and self._pinned_pool is not None: + self.triton_pinned_descs = self._pinned_pool.register_with_nixl(self.agent) + + def _init_triton_transfer_buffers(self): + """Initialize GPU staging buffer and shared pinned buffer pool for Triton KV transfer.""" + # Get dtype from KV cache buffers (supports fp8, fp16, bf16) + k_buffers = self.kv_args.k_buffers + if k_buffers is not None and len(k_buffers) > 0: + kv_dtype = k_buffers[0].dtype + kv_elem_bytes = k_buffers[0].element_size() + else: + # Fallback to bfloat16 if k_buffers not available yet + kv_dtype = torch.bfloat16 + kv_elem_bytes = 2 + logger.warning( + "[TRITON-KV] k_buffers not available, falling back to bfloat16. " + "This may cause issues if KV cache uses a different dtype (e.g., fp8)." 
+ ) + + # Allocate GPU staging buffer (fixed size, 256MB by default) + staging_size_bytes = int(DEFAULT_TRITON_STAGING_BUFFER_SIZE_MB * 1e6) + staging_elements = staging_size_bytes // kv_elem_bytes + self.triton_staging_buffer = torch.empty( + staging_elements, dtype=kv_dtype, device=f"cuda:{self.kv_args.gpu_id}" + ) + + # Get or create shared pinned buffer pool for this GPU + pinned_size_bytes = int( + getattr(self._server_args, "nixl_cpu_buffer_size_gb", 16.0) * 1e9 + ) + self._pinned_pool = PinnedBufferPool.get_or_create( + gpu_id=self.kv_args.gpu_id, + dtype=kv_dtype, + total_size_bytes=pinned_size_bytes, + ) + + logger.info( + f"[TRITON-KV] Initialized transfer buffers: " + f"staging={self.triton_staging_buffer.nbytes / 1e6:.2f}MB (GPU), " + f"shared_pinned_pool={pinned_size_bytes / 1e9:.2f}GB (CPU)" + ) + def _add_remote_peer(self, decode_kv_args: KVArgsRegisterInfo): agent_name = decode_kv_args.agent_name if agent_name in self.decode_kv_args_table: @@ -583,6 +723,369 @@ def make_req_array(addr_chunks, size, gpu): return xfer_handle + def _expand_pages_to_slots( + self, + page_indices: npt.NDArray[np.int32], + page_size: int, + device: torch.device, + ) -> torch.Tensor: + """Expand page indices to slot indices (each page has page_size slots).""" + pages = torch.from_numpy(page_indices).to(device, dtype=torch.int64) + offsets = torch.arange(page_size, device=device, dtype=torch.int64) + return (pages.unsqueeze(1) * page_size + offsets).flatten() + + def send_kvcache_triton( + self, + peer_name: str, + prefill_kv_indices: npt.NDArray[np.int32], + dst_pinned_ptr: int, + dst_pinned_size: int, + notif: str, + head_start: int = 0, + num_heads_to_send: int = None, + dst_head_offset: int = 0, + ): + """ + Send KV cache using Triton gather kernel + single NIXL transfer. + + This method: + 1. Allocates a region from the shared pinned buffer pool + 2. Uses gather_kv_to_pinned_all_layers to collect scattered KV data into the region + 3. 
Records a CUDA event and returns (event, post_fn) + + The caller should poll event.query() and call post_fn() when the event fires. + post_fn() initiates the NIXL transfer and returns (handles, pool_allocations). + """ + gather_kv_all_layers, _ = _import_triton_kv_transfer() + if gather_kv_all_layers is None: + raise RuntimeError( + "[TRITON-KV] Triton KV transfer not available. " + "Make sure triton is installed." + ) + + if self.kv_args.k_buffers is None or self.kv_args.v_buffers is None: + raise RuntimeError( + "[TRITON-KV] k_buffers and v_buffers must be set in KVArgs " + "when using Triton KV transfer." + ) + + if self._pinned_pool is None: + raise RuntimeError( + "[TRITON-KV] Pinned buffer pool not initialized." + ) + + k_buffers = self.kv_args.k_buffers + v_buffers = self.kv_args.v_buffers + num_layers = len(k_buffers) + num_heads = k_buffers[0].shape[1] + head_dim = k_buffers[0].shape[2] + device = k_buffers[0].device + + if num_heads_to_send is None: + num_heads_to_send = num_heads - head_start + + # Convert page indices to slot indices + page_size = self.kv_args.page_size + slot_indices_tensor = self._expand_pages_to_slots( + prefill_kv_indices, page_size, device + ).to(torch.int32) + num_tokens = len(slot_indices_tensor) + + # Calculate transfer size + bytes_per_element = k_buffers[0].element_size() + transfer_elements = num_layers * 2 * num_tokens * num_heads_to_send * head_dim + transfer_bytes = transfer_elements * bytes_per_element + + # Allocate region from shared pinned buffer pool + src_offset, buffer_region = self._pinned_pool.allocate(transfer_bytes) + + logger.debug( + f"[TRITON-KV] send_kvcache_triton: {num_tokens} tokens, {num_layers} layers, " + f"heads [{head_start}:{head_start + num_heads_to_send}], " + f"transfer_size={transfer_bytes / 1e6:.2f}MB, pool_offset={src_offset}" + ) + + # Create pointer tensors (cached for reuse) + if not hasattr(self, '_k_data_ptrs') or self._k_data_ptrs is None: + self._k_data_ptrs = torch.tensor( + 
[x.data_ptr() for x in k_buffers], dtype=torch.uint64, device=device + ) + self._v_data_ptrs = torch.tensor( + [x.data_ptr() for x in v_buffers], dtype=torch.uint64, device=device + ) + self._src_slot_stride = k_buffers[0].stride(0) + self._src_head_stride = k_buffers[0].stride(1) + + # Gather KV data to allocated region using single-kernel Triton (device->host) + gather_kv_all_layers( + k_data_ptrs=self._k_data_ptrs, + v_data_ptrs=self._v_data_ptrs, + slot_indices=slot_indices_tensor, + pinned_output=buffer_region, + head_start=head_start, + num_heads_to_gather=num_heads_to_send, + num_layers=num_layers, + head_dim=head_dim, + src_slot_stride=self._src_slot_stride, + src_head_stride=self._src_head_stride, + kv_elem_bytes=bytes_per_element, + ) + + # Record CUDA event — poll() will call post_fn() once event.query() is True, + # ensuring the gather kernel has written all data to pinned memory before NIXL reads it. + event = torch.cuda.Event() + event.record() + + # Capture variables needed by post_fn + head_stride_bytes = num_layers * 2 * num_tokens * head_dim * bytes_per_element + dst_offset = dst_head_offset * head_stride_bytes + buf_ptr = buffer_region.data_ptr() + pool_ref = self._pinned_pool + + def post_fn(): + if dst_pinned_ptr == 0: + pool_ref.release(src_offset) + raise RuntimeError( + f"[TRITON-KV] Invalid dst_pinned_ptr=0 for {peer_name}." + ) + + # Checksum the gather output AFTER the CUDA event has fired, + # confirming the gather kernel completed before NIXL reads it. 
+ room_for_log = int(notif.split("_")[0]) + logger.debug(  # per-transfer address trace; DEBUG level keeps it out of default logs + f"[DBG-NIXL-WRITE] room={room_for_log} " + f"src=0x{buf_ptr:x} dst=0x{dst_pinned_ptr + dst_offset:x} " + f"size={transfer_bytes}" + ) + if _NIXL_DEBUG_CHECKSUM: + _kv_checksum( + buffer_region, + f"PREFILL-AFTER-GATHER peer={peer_name}", + room_for_log, + ) + + src_addrs = [(buf_ptr, transfer_bytes, 0)] + dst_addrs = [(dst_pinned_ptr + dst_offset, transfer_bytes, 0)] + + src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM") + dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM") + + xfer_handle = self.agent.initialize_xfer( + "WRITE", src_descs, dst_descs, peer_name, notif.encode("ascii") + ) + if not xfer_handle: + pool_ref.release(src_offset) + raise Exception("[TRITON-KV] Failed to create Triton KV transfer") + + state = self.agent.transfer(xfer_handle) + if state == "ERR": + pool_ref.release(src_offset) + raise Exception("[TRITON-KV] Failed to post Triton KV transfer") + + return [xfer_handle], [(pool_ref, src_offset)] + + return event, post_fn + + def _send_kvcache_triton_batched( + self, + requests: List[tuple], + prefill_kv_indices: npt.NDArray[np.int32], + total_heads: int, + ): + """ + Batched KV transfer: ONE gather of ALL heads, then slice buffer for parallel NIXL transfers. + + Args: + requests: List of (agent_name, dst_pinned_ptr, dst_pinned_size, notif, head_start, num_heads) + prefill_kv_indices: Page indices to transfer + total_heads: Total number of KV heads on this prefill rank + + Returns: + Tuple of (event, post_fn) where post_fn() initiates all NIXL transfers and + returns (handles, pool_allocations). 
+ """ + gather_kv_all_layers, _ = _import_triton_kv_transfer() + if gather_kv_all_layers is None: + raise RuntimeError("[TRITON-KV] Triton KV transfer not available.") + + if self.kv_args.k_buffers is None or self.kv_args.v_buffers is None: + raise RuntimeError("[TRITON-KV] k_buffers and v_buffers must be set.") + + if self._pinned_pool is None: + raise RuntimeError("[TRITON-KV] Pinned buffer pool not initialized.") + + k_buffers = self.kv_args.k_buffers + v_buffers = self.kv_args.v_buffers + num_layers = len(k_buffers) + head_dim = k_buffers[0].shape[2] + device = k_buffers[0].device + + # Convert page indices to slot indices + page_size = self.kv_args.page_size + slot_indices_tensor = self._expand_pages_to_slots( + prefill_kv_indices, page_size, device + ).to(torch.int32) + num_tokens = len(slot_indices_tensor) + + # Calculate total buffer size for ALL heads + bytes_per_element = k_buffers[0].element_size() + total_transfer_bytes = num_layers * 2 * num_tokens * total_heads * head_dim * bytes_per_element + + # Allocate ONE buffer from pool for all heads + src_offset, buffer_region = self._pinned_pool.allocate(total_transfer_bytes) + + # Create pointer tensors (cached for reuse) + if not hasattr(self, '_k_data_ptrs') or self._k_data_ptrs is None: + self._k_data_ptrs = torch.tensor( + [x.data_ptr() for x in k_buffers], dtype=torch.uint64, device=device + ) + self._v_data_ptrs = torch.tensor( + [x.data_ptr() for x in v_buffers], dtype=torch.uint64, device=device + ) + self._src_slot_stride = k_buffers[0].stride(0) + self._src_head_stride = k_buffers[0].stride(1) + + # ONE gather of ALL heads + gather_kv_all_layers( + k_data_ptrs=self._k_data_ptrs, + v_data_ptrs=self._v_data_ptrs, + slot_indices=slot_indices_tensor, + pinned_output=buffer_region, + head_start=0, + num_heads_to_gather=total_heads, + num_layers=num_layers, + head_dim=head_dim, + src_slot_stride=self._src_slot_stride, + src_head_stride=self._src_head_stride, + kv_elem_bytes=bytes_per_element, + ) + + # 
Record CUDA event — poll() will call post_fn() once event.query() is True, + # ensuring the gather kernel has written all data to pinned memory before NIXL reads it. + event = torch.cuda.Event() + event.record() + + # Capture variables needed by post_fn + head_stride_bytes = num_layers * 2 * num_tokens * head_dim * bytes_per_element + buf_data_ptr = buffer_region.data_ptr() + pool_ref = self._pinned_pool + + def post_fn(): + handles = [] + # Checksum the FULL gather buffer once (after the CUDA event fires) + if _NIXL_DEBUG_CHECKSUM and requests: + first_notif = requests[0][3] + room_for_log = int(first_notif.split("_")[0]) + _kv_checksum( + buffer_region, + f"PREFILL-BATCHED-AFTER-GATHER nreqs={len(requests)}", + room_for_log, + ) + for agent_name, dst_pinned_ptr, dst_pinned_size, notif, head_start, num_heads in requests: + src_slice_ptr = buf_data_ptr + head_start * head_stride_bytes + slice_bytes = num_heads * head_stride_bytes + + if dst_pinned_ptr == 0: + pool_ref.release(src_offset) + raise RuntimeError( + f"[TRITON-KV-BATCHED] Invalid dst_pinned_ptr=0 for {agent_name}." 
+ ) + + src_addrs = [(src_slice_ptr, slice_bytes, 0)] + dst_addrs = [(dst_pinned_ptr, slice_bytes, 0)] + + src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM") + dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM") + + xfer_handle = self.agent.initialize_xfer( + "WRITE", src_descs, dst_descs, agent_name, notif.encode("ascii") + ) + if not xfer_handle: + pool_ref.release(src_offset) + raise Exception(f"[TRITON-KV-BATCHED] Failed to create transfer to {agent_name}") + + state = self.agent.transfer(xfer_handle) + if state == "ERR": + pool_ref.release(src_offset) + raise Exception(f"[TRITON-KV-BATCHED] Failed to post transfer to {agent_name}") + + handles.append(xfer_handle) + + return handles, [(pool_ref, src_offset)] + + return event, post_fn + + def scatter_received_kv( + self, + kv_indices: npt.NDArray[np.int32], + head_start: int = 0, + num_heads_received: int = None, + pinned_buffer: Optional[torch.Tensor] = None, + ): + """ + Scatter received KV data from pinned buffer to KV cache. + + Called on the receiver side after NIXL transfer completes. 
+ """ + _, scatter_kv_all_layers = _import_triton_kv_transfer() + if scatter_kv_all_layers is None: + raise RuntimeError("[TRITON-KV] Triton KV transfer not available.") + + if self.kv_args.k_buffers is None or self.kv_args.v_buffers is None: + raise RuntimeError("[TRITON-KV] k_buffers and v_buffers must be set.") + + if self._pinned_pool is None: + raise RuntimeError("[TRITON-KV] Pinned buffer pool not initialized.") + + k_buffers = self.kv_args.k_buffers + v_buffers = self.kv_args.v_buffers + num_layers = len(k_buffers) + num_heads = k_buffers[0].shape[1] + head_dim = k_buffers[0].shape[2] + device = k_buffers[0].device + + if num_heads_received is None: + num_heads_received = num_heads - head_start + + # Convert page indices to slot indices + page_size = self.kv_args.page_size + slot_indices_tensor = self._expand_pages_to_slots( + kv_indices, page_size, device + ).to(torch.int32) + num_tokens = len(slot_indices_tensor) + + bytes_per_element = k_buffers[0].element_size() + + # Create pointer tensors (cached for reuse) + if not hasattr(self, '_k_data_ptrs') or self._k_data_ptrs is None: + self._k_data_ptrs = torch.tensor( + [x.data_ptr() for x in k_buffers], dtype=torch.uint64, device=device + ) + self._v_data_ptrs = torch.tensor( + [x.data_ptr() for x in v_buffers], dtype=torch.uint64, device=device + ) + self._dst_slot_stride = k_buffers[0].stride(0) + self._dst_head_stride = k_buffers[0].stride(1) + + # Scatter from the per-request allocated region (or whole pool as fallback) to KV cache. + # No CPU sync needed: the scatter kernel runs on the default CUDA stream, and the + # subsequent model forward pass also runs on that stream, so GPU stream ordering + # guarantees the scatter completes before the forward reads the KV cache. 
+ input_buffer = pinned_buffer if pinned_buffer is not None else self._pinned_pool.buffer + scatter_kv_all_layers( + pinned_input=input_buffer, + k_data_ptrs=self._k_data_ptrs, + v_data_ptrs=self._v_data_ptrs, + slot_indices=slot_indices_tensor, + head_start=head_start, + num_heads_to_scatter=num_heads_received, + num_layers=num_layers, + head_dim=head_dim, + dst_slot_stride=self._dst_slot_stride, + dst_head_stride=self._dst_head_stride, + kv_elem_bytes=bytes_per_element, + ) + def send_aux( self, peer_name: str, @@ -731,11 +1234,104 @@ def add_transfer_request( aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, ): + """ + Add a transfer request for KV cache data. + + Returns a tuple of (handles, pool_allocations, pending_posts): + - handles: List of NIXL transfer handles created by this call + - pool_allocations: List of (pool, offset) tuples for later release + - pending_posts: List of (CUDA event, post_fn) pairs for deferred CPU-buffer sends; post_fn() posts the NIXL write and returns (handles, pool_allocations) + """ assert self.disaggregation_mode == DisaggregationMode.PREFILL assert not is_last or (is_last and aux_index is not None) reqs_to_be_processed = self.transfer_infos[bootstrap_room].values() handles = [] + pool_allocations = [] + pending_posts = [] + + # Filter out dummy requests for CPU buffer batched path detection + active_reqs = [req for req in reqs_to_be_processed if not req.is_dummy()] + + # Detect batched Triton case: prefill_tp < decode_tp with multiple destinations + if active_reqs: + first_decode_info = self.decode_kv_args_table.get(active_reqs[0].agent_name) + if first_decode_info: + prefill_tp_size = self.attn_tp_size + decode_tp_size = first_decode_info.decode_tp_size + + use_batched = ( + self.nixl_use_cpu_buffer + and prefill_tp_size < decode_tp_size + and all( + self.decode_kv_args_table[r.agent_name].dst_pinned_ptr != 0 + for r in active_reqs + ) + and self.kv_args.k_buffers is not None + and self.kv_args.v_buffers is not None + and not self.is_mla_backend + ) + + if use_batched: + # Collect batch request info + num_kv_heads = self.kv_args.kv_head_num + total_prefill_heads = num_kv_heads * 
prefill_tp_size + heads_per_decode_rank = total_prefill_heads // decode_tp_size + # Decode ranks that connect to this prefill rank are grouped in + # a contiguous block. Use the relative rank within that block so + # head_start stays in [0, num_kv_heads). + decode_per_prefill = decode_tp_size // prefill_tp_size + + batch_requests = [] + for req in active_reqs: + decode_info = self.decode_kv_args_table[req.agent_name] + decode_tp_rank = decode_info.decode_tp_rank % decode_tp_size + relative_decode_rank = decode_tp_rank % decode_per_prefill + head_start = relative_decode_rank * heads_per_decode_rank + logger.debug( + f"[MIXED-TP-BATCHED] prefill_tp={prefill_tp_size}, " + f"decode_tp={decode_tp_size}, decode_tp_rank={decode_tp_rank}, " + f"decode_per_prefill={decode_per_prefill}, " + f"relative_decode_rank={relative_decode_rank}, " + f"head_start={head_start}, heads_per_decode={heads_per_decode_rank}" + ) + notif = f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.pp_rank}" + # Use per-request allocated ptr if available, else fall back to static pool start + effective_dst_ptr = req.dst_pinned_ptr if req.dst_pinned_ptr != 0 else decode_info.dst_pinned_ptr + effective_dst_size = req.dst_pinned_size if req.dst_pinned_size != 0 else decode_info.dst_pinned_size + batch_requests.append(( + req.agent_name, + effective_dst_ptr, + effective_dst_size, + notif, + head_start, + heads_per_decode_rank, + )) + + batch_event, batch_post_fn = self._send_kvcache_triton_batched( + batch_requests, kv_indices, num_kv_heads + ) + pending_posts.append((batch_event, batch_post_fn)) + + # Handle aux data separately + if is_last: + for req in active_reqs: + assert aux_index is not None + decode_info = self.decode_kv_args_table[req.agent_name] + aux_xfer_handle = self.send_aux( + req.agent_name, + aux_index, + decode_info.dst_aux_ptrs, + req.dst_aux_index, + f"{req.room}_aux", + ) + handles.append(aux_xfer_handle) + + if is_last: + del self.transfer_infos[bootstrap_room] + + return 
handles, pool_allocations, pending_posts + for req in reqs_to_be_processed: assert bootstrap_room == req.room if req.is_dummy(): @@ -746,36 +1342,111 @@ def add_transfer_request( assert req.agent_name in self.decode_kv_args_table notif = f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.pp_rank}" - decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size + decode_info = self.decode_kv_args_table[req.agent_name] + decode_tp_size = decode_info.decode_tp_size + + # Check if CPU buffer Triton transfer is enabled and supported + prefill_tp_size = self.attn_tp_size + use_cpu_buffer = ( + self.nixl_use_cpu_buffer + and decode_info.dst_pinned_ptr != 0 + and self.kv_args.k_buffers is not None + and self.kv_args.v_buffers is not None + and not self.is_mla_backend + ) - if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): + kv_xfer_handle = None + if use_cpu_buffer and prefill_tp_size >= decode_tp_size: + # Triton CPU buffer path for same-TP or prefill_tp > decode_tp. + # + # Bug fix: chunked prefill sends kv_indices in multiple send() + # calls (chunk_id 0, 1, ...). Without accumulation, each chunk + # computes dst_offset using its own num_tokens, so chunk N + # overwrites chunk N-1 at the same destination offset, leaving + # the second half of the decode CPU buffer as zeros. + # + # Fix: accumulate all per-chunk kv_indices and issue a single + # gather + NIXL WRITE only when is_last=True, covering the full + # N_total-token buffer in one shot. Decode always receives + # exactly one KV notification per prefill TP rank (chunk_id=0, + # is_last=True), so its is_done() logic is unaffected. + num_kv_heads = self.kv_args.kv_head_num + local_tp_rank = self.kv_args.engine_rank % prefill_tp_size + + if prefill_tp_size > decode_tp_size: + head_start = 0 + num_heads_to_send = num_kv_heads + # Use the rank relative to the decode bucket so dst_head_offset + # stays within [0, num_kv_heads * prefill_ranks_per_decode). 
+ prefill_ranks_per_decode = prefill_tp_size // decode_tp_size + dst_head_offset = (local_tp_rank % prefill_ranks_per_decode) * num_kv_heads + logger.debug( + f"[MIXED-TP] prefill_tp={prefill_tp_size}, decode_tp={decode_tp_size}, " + f"local_tp_rank={local_tp_rank}, num_kv_heads={num_kv_heads}, " + f"prefill_ranks_per_decode={prefill_ranks_per_decode}, " + f"dst_head_offset={dst_head_offset}" + ) + else: + head_start = 0 + num_heads_to_send = num_kv_heads + dst_head_offset = 0 + + # Accumulate kv_indices across chunks for this request. + pending_kv = self._cpu_pending_kv.setdefault(bootstrap_room, []) + pending_kv.append(kv_indices) + + if is_last: + # Combine all accumulated chunks into one contiguous array. + all_kv_indices = ( + np.concatenate(pending_kv) if len(pending_kv) > 1 + else pending_kv[0] + ) + del self._cpu_pending_kv[bootstrap_room] + + # Use per-request allocated ptr if available, else fall back to static pool start + effective_dst_ptr = req.dst_pinned_ptr if req.dst_pinned_ptr != 0 else decode_info.dst_pinned_ptr + effective_dst_size = req.dst_pinned_size if req.dst_pinned_size != 0 else decode_info.dst_pinned_size + + # Always use chunk_id=0 / is_last=True for the CPU buffer + # path: we emit exactly one KV notification per request. + cpu_notif = f"{req.room}_kv_0_1_{self.kv_args.pp_rank}" + kv_event, kv_post_fn = self.send_kvcache_triton( + peer_name=req.agent_name, + prefill_kv_indices=all_kv_indices, + dst_pinned_ptr=effective_dst_ptr, + dst_pinned_size=effective_dst_size, + notif=cpu_notif, + head_start=head_start, + num_heads_to_send=num_heads_to_send, + dst_head_offset=dst_head_offset, + ) + pending_posts.append((kv_event, kv_post_fn)) + # else: not the last chunk — accumulate only, defer send. 
+ elif self.is_mla_backend or (decode_tp_size == self.attn_tp_size): kv_xfer_handle = self.send_kvcache( req.agent_name, kv_indices, - self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, + decode_info.dst_kv_ptrs, chunked_dst_kv_indice, - self.decode_kv_args_table[req.agent_name].gpu_id, + decode_info.gpu_id, notif, ) else: kv_xfer_handle = self.send_kvcache_slice( req.agent_name, kv_indices, - self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, + decode_info.dst_kv_ptrs, chunked_dst_kv_indice, - self.decode_kv_args_table[req.agent_name].gpu_id, + decode_info.gpu_id, notif, prefill_tp_size=self.attn_tp_size, decode_tp_size=decode_tp_size, - decode_tp_rank=self.decode_kv_args_table[ - req.agent_name - ].decode_tp_rank, - dst_kv_item_len=self.decode_kv_args_table[ - req.agent_name - ].dst_kv_item_len, + decode_tp_rank=decode_info.decode_tp_rank, + dst_kv_item_len=decode_info.dst_kv_item_len, ) - handles.append(kv_xfer_handle) + if kv_xfer_handle is not None: + handles.append(kv_xfer_handle) # Only the last chunk we need to send the aux data. if is_last: if state_indices is not None: @@ -796,48 +1467,74 @@ def add_transfer_request( aux_xfer_handle = self.send_aux( req.agent_name, aux_index, - self.decode_kv_args_table[req.agent_name].dst_aux_ptrs, + decode_info.dst_aux_ptrs, req.dst_aux_index, f"{req.room}_aux", ) handles.append(aux_xfer_handle) if is_last: del self.transfer_infos[bootstrap_room] - return handles + return handles, pool_allocations, pending_posts + + def _drain_deferred_pool_releases(self) -> None: + """Release pinned-buffer regions whose scatter kernels have completed. + + Checks each pending (CUDA event, pool offset) pair. If the event has + fired (GPU kernel done), the pool region is released immediately so it + can be reused by the next NIXL transfer. Pending entries whose events + have *not* yet fired are kept for the next call. 
+ + This is called both from ``update_transfer_status()`` (normal poll path) + and from ``NixlKVReceiver.init()`` *before* blocking on pool allocation, + to avoid a deadlock where the allocator waits for space that would only + be freed after ``poll()`` runs on an already-allocated request. + """ + if not self._pending_pool_releases or self._pinned_pool is None: + return + remaining = [] + for event, offset in self._pending_pool_releases: + if event.query(): + self._pinned_pool.release(offset) + else: + remaining.append((event, offset)) + self._pending_pool_releases = remaining def update_transfer_status(self): + # Drain deferred pinned-buffer releases from completed scatter kernels. + self._drain_deferred_pool_releases() + # Process notifications from received transfers. notif_map = self.agent.get_new_notifs() for peer_name, messages in notif_map.items(): - # We could also check that self.bootstrap_info['agent_name'] matches - # the message sender. But the bootstrap room alone should be - # sufficient to map the status. + # Use peer_name as the unique sender key. This correctly handles + # mixed TP where multiple prefill TP ranks (each with a unique + # NIXL agent/peer_name) send to the same decode rank. 
for msg in messages: components = msg.decode("ascii").split("_", 4) room = int(components[0]) if components[1] == "kv": chunk_id = int(components[2]) is_last = bool(int(components[3])) - pp_rank = int(components[4]) if len(components) > 4 else 0 - # Track received chunks per pp_rank - self.transfer_statuses[room].received_kvs_per_pp[pp_rank].add( - chunk_id - ) + sender_key = peer_name + # Track received chunks per sender + self.transfer_statuses[room].received_kvs_per_sender[ + sender_key + ].add(chunk_id) if is_last: - # Record expected chunk count for this pp_rank - self.transfer_statuses[room].expected_kvs_per_pp[pp_rank] = ( - chunk_id + 1 - ) - # Set num_pp_ranks_expected from table (or default to 1) - if self.transfer_statuses[room].num_pp_ranks_expected is None: - self.transfer_statuses[room].num_pp_ranks_expected = ( + # Record expected chunk count for this sender + self.transfer_statuses[room].expected_kvs_per_sender[ + sender_key + ] = (chunk_id + 1) + # Set num_senders_expected from table (or default to 1) + if self.transfer_statuses[room].num_senders_expected is None: + self.transfer_statuses[room].num_senders_expected = ( self.required_prefill_response_num_table.get(room, 1) ) elif components[1] == "aux": self.transfer_statuses[room].received_aux = True elif components[1] == "state": - pp_rank = int(components[2]) if len(components) > 2 else 0 - self.transfer_statuses[room].received_state_per_pp.add(pp_rank) + sender_key = peer_name + self.transfer_statuses[room].received_state_per_sender.add(sender_key) def check_transfer_done(self, room: int): if room not in self.transfer_statuses: @@ -893,6 +1590,10 @@ def __init__( ): super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank) self.xfer_handles = [] + # Track pool allocations for release when transfer completes + self._pool_allocations: List[tuple] = [] + # Pending (event, post_fn) pairs: NIXL not yet posted, waiting for gather kernel + self._pending_posts: List[tuple] = [] 
self.has_sent = False self.chunk_id = 0 @@ -905,16 +1606,20 @@ def send( self.curr_idx += len(kv_indices) is_last = self.curr_idx == self.num_kv_indices - new_xfer_handles = self.kv_mgr.add_transfer_request( - self.bootstrap_room, - kv_indices, - index_slice, - is_last, - self.chunk_id, - self.aux_index, - state_indices, + new_xfer_handles, new_pool_allocations, new_pending_posts = ( + self.kv_mgr.add_transfer_request( + self.bootstrap_room, + kv_indices, + index_slice, + is_last, + self.chunk_id, + self.aux_index, + state_indices, + ) ) self.xfer_handles.extend(new_xfer_handles) + self._pool_allocations.extend(new_pool_allocations) + self._pending_posts.extend(new_pending_posts) self.chunk_id += 1 if is_last: self.has_sent = True @@ -923,10 +1628,36 @@ def send( def poll(self) -> KVPoll: if not self.has_sent: return self.kv_mgr.check_status(self.bootstrap_room) + + # Drain pending gather events: once a CUDA event fires, post the NIXL transfer. + if self._pending_posts: + remaining = [] + for event, post_fn in self._pending_posts: + if event.query(): + new_handles, new_allocs = post_fn() + self.xfer_handles.extend(new_handles) + self._pool_allocations.extend(new_allocs) + else: + remaining.append((event, post_fn)) + self._pending_posts = remaining + if self._pending_posts: + return KVPoll.WaitingForInput # type: ignore + + if not self.xfer_handles: + return KVPoll.WaitingForInput # type: ignore + states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles] if all([x == "DONE" for x in states]): + # Release pool allocations now that all transfers are complete + for pool, offset in self._pool_allocations: + pool.release(offset) + self._pool_allocations.clear() return KVPoll.Success # type: ignore if any([x == "ERR" for x in states]): + # Release pool allocations on error too + for pool, offset in self._pool_allocations: + pool.release(offset) + self._pool_allocations.clear() raise Exception("KVSender transfer encountered an error.") return 
KVPoll.WaitingForInput # type: ignore @@ -952,6 +1683,12 @@ def __init__( self.bootstrap_room ) self.init_time = None + # Store kv_indices for Triton scatter after transfer completes + self._triton_kv_indices: Optional[npt.NDArray[np.int32]] = None + self._triton_scatter_done = False + # Per-request pinned buffer allocation on the receive side + self._recv_pool_offset: Optional[int] = None + self._recv_pool_buffer_view: Optional[torch.Tensor] = None def init( self, @@ -966,6 +1703,40 @@ def init( self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) return + # For CPU buffer transfers, allocate a per-request region from the receive pool. + # Each concurrent request gets its own region so NIXL writes don't overwrite each other. + # We send the allocated ptr to the prefill so it writes to our unique offset. + recv_pinned_ptr = 0 + recv_pinned_size = 0 + if ( + self.kv_mgr.nixl_use_cpu_buffer + and self.kv_mgr._pinned_pool is not None + and self.kv_mgr.kv_args.k_buffers is not None + ): + k_buffers = self.kv_mgr.kv_args.k_buffers + num_layers = len(k_buffers) + num_heads = k_buffers[0].shape[1] + head_dim = k_buffers[0].shape[2] + bytes_per_element = k_buffers[0].element_size() + num_tokens = len(kv_indices) * self.kv_mgr.kv_args.page_size + recv_pinned_size = ( + num_layers * 2 * num_tokens * num_heads * head_dim * bytes_per_element + ) + # Drain any deferred releases from completed scatter kernels before + # allocating, so we don't block if a previous request's scatter has + # already finished but its pool region hasn't been freed yet. 
+ self.kv_mgr._drain_deferred_pool_releases() + recv_offset, recv_buffer_view = self.kv_mgr._pinned_pool.allocate( + recv_pinned_size + ) + self._recv_pool_offset = recv_offset + self._recv_pool_buffer_view = recv_buffer_view + recv_pinned_ptr = self.kv_mgr._pinned_pool.buffer.data_ptr() + recv_offset + logger.warning( + f"[DBG-ALLOC] room={self.bootstrap_room} offset={recv_offset} " + f"ptr=0x{recv_pinned_ptr:x} size={recv_pinned_size}" + ) + for bootstrap_info in self.bootstrap_infos: logger.debug( f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}" @@ -991,6 +1762,8 @@ def init( if not is_dummy and state_indices is not None else b"" ), + str(recv_pinned_ptr).encode("ascii"), + str(recv_pinned_size).encode("ascii"), ] ) @@ -1001,6 +1774,11 @@ def init( self.started_transfer = True self.init_time = time.time() + # Store kv_indices for Triton scatter after transfer completes + if self.kv_mgr.nixl_use_cpu_buffer: + self._triton_kv_indices = kv_indices.copy() + self._triton_scatter_done = False + def poll(self) -> KVPoll: if self.conclude_state is not None: return self.conclude_state @@ -1020,6 +1798,7 @@ def poll(self) -> KVPoll: self.bootstrap_room, f"Request {self.bootstrap_room} timed out after {elapsed:.1f}s in KVPoll.WaitingForInput", ) + self._release_recv_pool() self.conclude_state = KVPoll.Failed return KVPoll.Failed @@ -1030,16 +1809,88 @@ def poll(self) -> KVPoll: ) # Check if the transfer failed if self.kv_mgr.transfer_statuses[self.bootstrap_room].is_failed(): + self._release_recv_pool() self.conclude_state = KVPoll.Failed logger.error( f"Transfer for room {self.bootstrap_room} failed due to node failure" ) else: + # For CPU buffer transfer, scatter received data from pinned buffer to GPU KV cache + if ( + self.kv_mgr.nixl_use_cpu_buffer + and self._triton_kv_indices is not None + and not self._triton_scatter_done + ): + try: + # Checksum the receive buffer AFTER the NIXL notification + # confirms the transfer 
is done, but BEFORE the scatter + # kernel reads it. Compare to the prefill-side checksum + # in the prefill log for the same room to detect: + # - all-zero buffer → NIXL write never reached this buffer + # - value mismatch → wrong buffer address / pool aliasing + # - partial zeros → transfer not fully written yet + buf_ptr = self._recv_pool_buffer_view.data_ptr() if self._recv_pool_buffer_view is not None else 0 + logger.warning( + f"[DBG-SCATTER] room={self.bootstrap_room} " + f"buf_ptr=0x{buf_ptr:x} offset={self._recv_pool_offset}" + ) + if _NIXL_DEBUG_CHECKSUM and self._recv_pool_buffer_view is not None: + _kv_checksum( + self._recv_pool_buffer_view, + "DECODE-BEFORE-SCATTER", + self.bootstrap_room, + ) + self.kv_mgr.scatter_received_kv( + kv_indices=self._triton_kv_indices, + head_start=0, + num_heads_received=None, + pinned_buffer=self._recv_pool_buffer_view, + ) + self._triton_scatter_done = True + # Defer pool release until scatter kernel completes on GPU. + # Record a CUDA event now (after kernel launch) and hand the + # offset to the manager's deferred-release list. The pool + # region must not be reused until the GPU kernel has finished + # reading from it, otherwise a concurrent NIXL write could + # corrupt the buffer before the scatter is done. 
+ self._defer_recv_pool_release() + except Exception as e: + logger.error( + f"[TRITON-KV] Scatter failed: room={self.bootstrap_room}, error={e}" + ) + self._release_recv_pool() + self.conclude_state = KVPoll.Failed + del self.kv_mgr.transfer_statuses[self.bootstrap_room] + return KVPoll.Failed + self.conclude_state = KVPoll.Success del self.kv_mgr.transfer_statuses[self.bootstrap_room] return self.conclude_state # type: ignore return KVPoll.WaitingForInput # type: ignore + def _release_recv_pool(self): + """Release the per-request pinned buffer allocation immediately.""" + if self._recv_pool_offset is not None and self.kv_mgr._pinned_pool is not None: + self.kv_mgr._pinned_pool.release(self._recv_pool_offset) + self._recv_pool_offset = None + self._recv_pool_buffer_view = None + + def _defer_recv_pool_release(self): + """Defer pinned-buffer release until the scatter kernel finishes on GPU. + + Records a CUDA event on the current stream immediately after the scatter + kernel launch and hands the (event, offset) pair to the manager's + deferred-release list. The pool region is freed in + ``update_transfer_status()`` once ``event.query()`` returns True. 
+ """ + if self._recv_pool_offset is None: + return + event = torch.cuda.Event() + event.record() + self.kv_mgr._pending_pool_releases.append((event, self._recv_pool_offset)) + self._recv_pool_offset = None + self._recv_pool_buffer_view = None + def _register_kv_args(self): for bootstrap_info in self.bootstrap_infos: sock, lock = self._connect_to_bootstrap_server(bootstrap_info) @@ -1053,6 +1904,15 @@ def _register_kv_args(self): struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs ) + # Get pinned buffer info for CPU buffer KV transfer + pinned_ptr = 0 + pinned_size = 0 + if ( + self.kv_mgr.nixl_use_cpu_buffer + and self.kv_mgr._pinned_pool is not None + ): + pinned_ptr, pinned_size = self.kv_mgr._pinned_pool.get_buffer_info() + with lock: sock.send_multipart( [ @@ -1069,6 +1929,8 @@ def _register_kv_args(self): str(self.kv_mgr.kv_args.decode_tp_size).encode("ascii"), str(self.kv_mgr.kv_args.engine_rank).encode("ascii"), str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"), + str(pinned_ptr).encode("ascii"), + str(pinned_size).encode("ascii"), ] ) diff --git a/python/sglang/srt/disaggregation/nixl/pinned_buffer_pool.py b/python/sglang/srt/disaggregation/nixl/pinned_buffer_pool.py new file mode 100644 index 000000000000..a0160f151019 --- /dev/null +++ b/python/sglang/srt/disaggregation/nixl/pinned_buffer_pool.py @@ -0,0 +1,234 @@ +""" +Unified pinned buffer pool for NIXL KV transfers. + +This module provides a per-GPU singleton pool that shares a single pinned buffer +across all NixlKVManager instances on the same GPU, avoiding double allocation +when a decode node does both receiving (from prefill) and sending (for decode->decode +migration). 
class PinnedBufferPool:
    """
    Per-GPU singleton that provides a shared pinned buffer for KV transfers.
    Both receiver (prefill->decode) and sender (decode->decode migration) use this.

    Pre-allocates the full buffer at creation to avoid runtime latency spikes.
    Uses range-based, first-fit allocation for variable-sized concurrent
    transfers. Regions are identified by their byte offset, so every live
    allocation is guaranteed a distinct offset.
    """

    _instances: Dict[int, "PinnedBufferPool"] = {}  # gpu_id -> pool
    _lock = threading.Lock()

    # Allocation granularity in bytes; every region is a multiple of this.
    _ALIGNMENT = 256
    # Seconds between "waiting for space" warnings while allocate() blocks.
    _WAIT_LOG_INTERVAL_S = 10.0

    @classmethod
    def get_or_create(
        cls,
        gpu_id: int,
        dtype: torch.dtype,
        total_size_bytes: int,
    ) -> "PinnedBufferPool":
        """Get existing pool for this GPU or create one (allocates immediately).

        When a pool already exists for ``gpu_id`` it is returned unchanged,
        even if ``dtype`` or ``total_size_bytes`` differ from the ones it was
        created with. That mismatch is logged so it does not go unnoticed.
        """
        with cls._lock:
            pool = cls._instances.get(gpu_id)
            if pool is None:
                pool = cls(gpu_id, dtype, total_size_bytes)
                cls._instances[gpu_id] = pool
            elif pool.dtype != dtype or pool.total_size_bytes != total_size_bytes:
                logger.warning(
                    f"[PinnedBufferPool] GPU {gpu_id}: reusing existing pool "
                    f"(dtype={pool.dtype}, size={pool.total_size_bytes}) and "
                    f"ignoring requested dtype={dtype}, size={total_size_bytes}"
                )
            return pool

    @classmethod
    def clear_instances(cls):
        """Clear all pool instances. Used for testing."""
        with cls._lock:
            cls._instances.clear()

    def __init__(self, gpu_id: int, dtype: torch.dtype, total_size_bytes: int):
        """Allocate the pinned buffer and set up empty range bookkeeping.

        Args:
            gpu_id: GPU this pool is associated with (used for logging/keys).
            dtype: Element dtype of the backing tensor.
            total_size_bytes: Requested pool capacity in bytes.
        """
        self.gpu_id = gpu_id
        self.dtype = dtype
        self.total_size_bytes = total_size_bytes

        # Pre-allocate the full buffer NOW
        elem_size = torch.tensor([], dtype=self.dtype).element_size()
        num_elements = total_size_bytes // elem_size
        logger.info(
            f"[PinnedBufferPool] Pre-allocating {total_size_bytes / 1e9:.2f}GB "
            f"pinned buffer for GPU {gpu_id}"
        )
        self._buffer = torch.empty(num_elements, dtype=self.dtype, pin_memory=True)
        logger.info("[PinnedBufferPool] Allocation complete")

        # Range tracking: sorted list of (start, end) byte ranges in use.
        self._allocated_ranges: List[Tuple[int, int]] = []
        self._range_lock = threading.Lock()
        self._range_available = threading.Condition(self._range_lock)

        # Track NIXL registrations per agent (each agent needs its own registration)
        self._nixl_descs_by_agent: Dict[str, Any] = {}
        self._warned_full = False

    def allocate(
        self, size_bytes: int, timeout: Optional[float] = None
    ) -> Tuple[int, torch.Tensor]:
        """
        Allocate a contiguous region of the given size.
        Blocks if no space available until space is released.

        The size is rounded up to a multiple of 256 bytes, with a minimum of
        one 256-byte unit. The minimum matters because regions are released
        by offset: a zero-length region would share its offset with the next
        allocation, and releasing either one would free both.

        Args:
            size_bytes: Number of bytes to allocate
            timeout: Optional timeout in seconds. None (default) waits indefinitely.
                     Only use finite timeout for testing.

        Returns: (offset_bytes, buffer_view)

        Raises:
            ValueError: if ``size_bytes`` is negative.
            RuntimeError: if the aligned request is larger than the whole pool
                (it could never be satisfied, so waiting would deadlock), or
                if ``timeout`` elapses before space becomes available.
        """
        if size_bytes < 0:
            raise ValueError(
                f"[PinnedBufferPool] size_bytes must be non-negative, got {size_bytes}"
            )

        alignment = self._ALIGNMENT
        aligned_size = max(alignment, -(-size_bytes // alignment) * alignment)
        if aligned_size > self.total_size_bytes:
            raise RuntimeError(
                f"[PinnedBufferPool] Request of {size_bytes} bytes "
                f"({aligned_size} after alignment) exceeds the total buffer "
                f"size of {self.total_size_bytes} bytes and can never be "
                f"satisfied. Consider increasing --nixl-cpu-buffer-size-gb."
            )

        log_interval = self._WAIT_LOG_INTERVAL_S

        with self._range_available:
            # Monotonic clock: immune to wall-clock adjustments while waiting.
            start = time.monotonic()
            deadline = start + timeout if timeout is not None else None
            next_log_at = start  # log on the first failed attempt

            while True:
                # Try to find a free region (first-fit)
                offset = self._find_free_region(aligned_size)
                if offset is not None:
                    self._allocated_ranges.append((offset, offset + aligned_size))
                    self._allocated_ranges.sort()  # keep sorted for the gap scan
                    self._warned_full = False

                    # Return view into buffer
                    elem_size = self._buffer.element_size()
                    start_elem = offset // elem_size
                    end_elem = start_elem + (aligned_size // elem_size)
                    return offset, self._buffer[start_elem:end_elem]

                # No space - log status periodically
                now = time.monotonic()
                if now >= next_log_at:
                    allocated_bytes = sum(e - s for s, e in self._allocated_ranges)
                    logger.warning(
                        f"[PinnedBufferPool] GPU {self.gpu_id}: Waiting for space. "
                        f"Need {size_bytes / 1e6:.1f}MB, "
                        f"buffer {allocated_bytes / 1e6:.1f}/{self.total_size_bytes / 1e6:.1f}MB used, "
                        f"{len(self._allocated_ranges)} active allocations. "
                        f"Consider increasing --nixl-cpu-buffer-size-gb if this persists."
                    )
                    next_log_at = now + log_interval

                # Check timeout (only used for testing)
                if deadline is None:
                    wait_time = log_interval
                else:
                    remaining = deadline - now
                    if remaining <= 0:
                        raise RuntimeError(
                            f"[PinnedBufferPool] Timeout waiting for pinned buffer space. "
                            f"Needed {size_bytes} bytes ({size_bytes / 1e6:.2f}MB), "
                            f"total buffer {self.total_size_bytes} bytes ({self.total_size_bytes / 1e9:.2f}GB), "
                            f"allocated ranges: {len(self._allocated_ranges)}. "
                            f"Consider increasing --nixl-cpu-buffer-size-gb."
                        )
                    wait_time = min(remaining, log_interval)

                # Woken by release() or after wait_time, whichever comes first.
                self._range_available.wait(timeout=wait_time)

    def _find_free_region(self, size_bytes: int) -> Optional[int]:
        """Find first free region that can fit size_bytes. Returns offset or None.

        Scans the sorted allocated ranges once, tracking the end of the last
        range seen; the first gap that is large enough wins. The caller must
        hold the range lock.
        """
        cursor = 0
        for start, end in self._allocated_ranges:
            if start - cursor >= size_bytes:
                return cursor
            cursor = max(cursor, end)
        if self.total_size_bytes - cursor >= size_bytes:
            return cursor
        return None

    def release(self, offset: int):
        """Release a previously allocated region.

        Unknown offsets are logged and otherwise ignored. Waiters blocked in
        ``allocate()`` are always notified.
        """
        with self._range_available:
            # Find and remove the range starting at this offset
            original_len = len(self._allocated_ranges)
            self._allocated_ranges = [
                (s, e) for s, e in self._allocated_ranges if s != offset
            ]
            if len(self._allocated_ranges) == original_len:
                logger.warning(
                    f"[PinnedBufferPool] Attempted to release unknown offset {offset}"
                )
            self._range_available.notify_all()

    def get_buffer_info(self) -> Tuple[int, int]:
        """Return (data_ptr, nbytes) for NIXL registration."""
        return (self._buffer.data_ptr(), self._buffer.nbytes)

    def register_with_nixl(self, agent) -> Any:
        """Register full buffer with NIXL agent.

        Each NIXL agent needs its own registration, even for the same buffer.
        This is because NIXL descriptors are agent-specific. Registrations are
        cached per ``agent.name`` and reused on later calls.

        Raises:
            Exception: if ``agent.register_memory`` returns a falsy result.
        """
        agent_name = agent.name
        if agent_name in self._nixl_descs_by_agent:
            logger.debug(
                f"[PinnedBufferPool] Buffer already registered with agent {agent_name}"
            )
            return self._nixl_descs_by_agent[agent_name]

        addr = [(self._buffer.data_ptr(), self._buffer.nbytes, 0, "")]
        descs = agent.register_memory(addr, "DRAM")
        if not descs:
            raise Exception(
                f"[PinnedBufferPool] NIXL memory registration failed for pinned buffer "
                f"with agent {agent_name}"
            )
        self._nixl_descs_by_agent[agent_name] = descs
        logger.info(
            f"[PinnedBufferPool] Registered pinned buffer with NIXL agent {agent_name}: "
            f"{self._buffer.nbytes / 1e9:.2f}GB"
        )
        return descs

    @property
    def buffer(self) -> torch.Tensor:
        """Get the underlying buffer tensor."""
        return self._buffer

    def get_stats(self) -> Dict[str, Any]:
        """Get allocation statistics for monitoring."""
        with self._range_lock:
            allocated_bytes = sum(e - s for s, e in self._allocated_ranges)
            return {
                "gpu_id": self.gpu_id,
                "total_bytes": self.total_size_bytes,
                "allocated_bytes": allocated_bytes,
                "free_bytes": self.total_size_bytes - allocated_bytes,
                "num_allocations": len(self._allocated_ranges),
                "utilization": allocated_bytes / self.total_size_bytes
                if self.total_size_bytes > 0
                else 0.0,
            }
buffer KV transfer + if getattr(self.scheduler.server_args, "nixl_use_cpu_buffer", False): + if hasattr(self.token_to_kv_pool, "k_buffer") and hasattr( + self.token_to_kv_pool, "v_buffer" + ): + kv_args.k_buffers = self.token_to_kv_pool.k_buffer + kv_args.v_buffers = self.token_to_kv_pool.v_buffer + kv_args.head_dim = self.token_to_kv_pool.head_dim + kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = ( self.metadata_buffers.get_buf_infos() ) diff --git a/python/sglang/srt/layers/attention/triton_ops/kv_transfer.py b/python/sglang/srt/layers/attention/triton_ops/kv_transfer.py new file mode 100644 index 000000000000..69425ee4e372 --- /dev/null +++ b/python/sglang/srt/layers/attention/triton_ops/kv_transfer.py @@ -0,0 +1,481 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Triton kernels for gathering/scattering KV cache data between GPU and pinned CPU memory. + +These kernels enable efficient KV cache transfers for disaggregated inference by: +1. Gathering scattered KV data from GPU into pinned CPU buffer (device -> host) +2. 
@triton.jit
def _gather_kv_all_layers_kernel(
    # Pointers to all K buffers (tensor of uint64 pointers)
    k_data_ptrs,
    # Pointers to all V buffers (tensor of uint64 pointers)
    v_data_ptrs,
    # Slot indices to gather
    slot_indices_ptr,
    # Output pinned CPU buffer
    output_ptr,
    # Dimensions
    num_layers,
    num_tokens,
    head_dim: tl.constexpr,
    # Head slicing params
    head_start: tl.constexpr,
    num_heads_to_gather: tl.constexpr,
    # Source strides (in elements) - same for all layers
    src_slot_stride,
    src_head_stride,
    # Output layout strides (HEAD-FIRST)
    out_head_stride,  # = num_layers * 2 * num_tokens * head_dim
    out_layer_stride,  # = 2 * num_tokens * head_dim
    out_kv_stride,  # = num_tokens * head_dim
    out_token_stride,  # = head_dim
    # Block sizes
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    # Element size: 1 for fp8, 2 for fp16/bf16
    ELEM_BYTES: tl.constexpr = 2,
):
    """
    Gather KV data from ALL layers in a single kernel launch.

    Reads scattered from GPU KV cache, writes contiguous to pinned CPU.
    O(1) extra GPU memory. Dtype-agnostic (copies raw bytes).

    Grid: (num_heads_to_gather, num_layers * 2) - one program per (head, layer_kv)
    - program_id(0): head index (0 to num_heads_to_gather-1)
    - program_id(1): layer_kv index (0 to num_layers*2-1), where even=K, odd=V

    Only ELEM_BYTES values 1 and 2 are handled; any other value takes the
    2-byte branch, so the host wrapper must reject other element sizes.
    """
    head_id = tl.program_id(0)
    layer_kv_id = tl.program_id(1)

    # Decode layer and K/V from combined index
    layer_id = layer_kv_id // 2
    is_v = layer_kv_id % 2  # 0 = K, 1 = V

    # Load the data pointer for this layer's K or V buffer
    # Use int8 or int16 based on element size for dtype-agnostic byte copying
    if ELEM_BYTES == 1:
        if is_v == 0:
            src_base_ptr = tl.load(k_data_ptrs + layer_id).to(
                tl.pointer_type(tl.int8)
            )
        else:
            src_base_ptr = tl.load(v_data_ptrs + layer_id).to(
                tl.pointer_type(tl.int8)
            )
    else:  # ELEM_BYTES == 2
        if is_v == 0:
            src_base_ptr = tl.load(k_data_ptrs + layer_id).to(
                tl.pointer_type(tl.int16)
            )
        else:
            src_base_ptr = tl.load(v_data_ptrs + layer_id).to(
                tl.pointer_type(tl.int16)
            )

    # Source head index (absolute in source KV cache)
    src_head = head_start + head_id

    # Cast strides to int64
    src_slot_stride_i64 = src_slot_stride.to(tl.int64)
    src_head_stride_i64 = src_head_stride.to(tl.int64)
    out_head_stride_i64 = out_head_stride.to(tl.int64)
    out_layer_stride_i64 = out_layer_stride.to(tl.int64)
    out_kv_stride_i64 = out_kv_stride.to(tl.int64)
    out_token_stride_i64 = out_token_stride.to(tl.int64)
    src_head_i64 = src_head.to(tl.int64)
    head_id_i64 = head_id.to(tl.int64)
    layer_id_i64 = layer_id.to(tl.int64)
    is_v_i64 = is_v.to(tl.int64)

    # Base output offset for this (head, layer, kv)
    out_base = (
        head_id_i64 * out_head_stride_i64
        + layer_id_i64 * out_layer_stride_i64
        + is_v_i64 * out_kv_stride_i64
    )

    # Process tokens in blocks
    for token_block_start in range(0, num_tokens, BLOCK_TOKENS):
        token_offsets = token_block_start + tl.arange(0, BLOCK_TOKENS)
        token_mask = token_offsets < num_tokens
        token_offsets_i64 = token_offsets.to(tl.int64)

        # Load slot indices for these tokens
        slot_ids = tl.load(slot_indices_ptr + token_offsets, mask=token_mask, other=0)
        slot_ids_i64 = slot_ids.to(tl.int64)

        # Process head_dim in blocks
        for dim_start in range(0, head_dim, BLOCK_DIM):
            dim_offsets = dim_start + tl.arange(0, BLOCK_DIM)
            dim_mask = dim_offsets < head_dim
            dim_offsets_i64 = dim_offsets.to(tl.int64)

            # Compute source addresses in GPU KV cache
            src_offsets = (
                slot_ids_i64[:, None] * src_slot_stride_i64
                + src_head_i64 * src_head_stride_i64
                + dim_offsets_i64[None, :]
            )

            # Load from GPU source
            mask = token_mask[:, None] & dim_mask[None, :]
            data = tl.load(src_base_ptr + src_offsets, mask=mask, other=0.0)

            # Compute output addresses
            out_offsets = (
                out_base
                + token_offsets_i64[:, None] * out_token_stride_i64
                + dim_offsets_i64[None, :]
            )

            # Store to output buffer
            tl.store(output_ptr + out_offsets, data, mask=mask)


def gather_kv_to_pinned_all_layers(
    k_data_ptrs: torch.Tensor,  # [num_layers] uint64 tensor of K buffer pointers
    v_data_ptrs: torch.Tensor,  # [num_layers] uint64 tensor of V buffer pointers
    slot_indices: torch.Tensor,  # [num_tokens] on GPU
    pinned_output: torch.Tensor,  # pinned CPU buffer
    head_start: int,
    num_heads_to_gather: int,
    num_layers: int,
    head_dim: int,
    src_slot_stride: int,  # stride between slots in source (num_heads * head_dim)
    src_head_stride: int,  # stride between heads in source (head_dim)
    kv_elem_bytes: int = None,  # element size of KV cache (1 for fp8, 2 for fp16/bf16)
) -> None:
    """
    Gather KV data from ALL layers using a SINGLE kernel launch.

    O(1) extra GPU memory - writes directly to pinned CPU memory.

    Args:
        k_data_ptrs: Tensor of uint64 pointers to each layer's K buffer
        v_data_ptrs: Tensor of uint64 pointers to each layer's V buffer
        slot_indices: Tensor of slot indices to gather
        pinned_output: Pinned CPU buffer to write to
        head_start: First head index to gather
        num_heads_to_gather: Number of heads to gather
        num_layers: Number of layers
        head_dim: Dimension of each head
        src_slot_stride: Stride between slots in source buffers (num_heads * head_dim)
        src_head_stride: Stride between heads in source buffers (head_dim)
        kv_elem_bytes: Element size of KV cache in bytes. Must match pinned_output.element_size().

    Output layout: [num_heads_to_gather, num_layers, 2, num_tokens, head_dim] (HEAD-FIRST)

    Raises:
        AssertionError: if the buffers are on the wrong device, the pointer
            tensors are not uint64, or ``kv_elem_bytes`` disagrees with the
            pinned buffer's element size.
        ValueError: if the element size is not 1 or 2 bytes (the only sizes
            the kernel copies correctly), if a pointer tensor has fewer than
            ``num_layers`` entries, or if ``pinned_output`` is too small for
            the gathered data.
    """
    assert pinned_output.is_pinned(), "Output buffer must be pinned CPU memory"
    assert slot_indices.is_cuda, "slot_indices must be on GPU"
    assert k_data_ptrs.dtype == torch.uint64, "k_data_ptrs must be uint64"
    assert v_data_ptrs.dtype == torch.uint64, "v_data_ptrs must be uint64"

    # Validate element size consistency between KV cache and pinned buffer
    elem_bytes = pinned_output.element_size()
    if kv_elem_bytes is not None:
        assert elem_bytes == kv_elem_bytes, (
            f"KV cache element size ({kv_elem_bytes} bytes) does not match "
            f"pinned buffer element size ({elem_bytes} bytes). "
            f"The pinned buffer dtype must match the KV cache dtype."
        )
    # The kernel only has 1-byte and 2-byte copy paths; anything else would be
    # copied as 2-byte elements with strides computed for the real dtype.
    if elem_bytes not in (1, 2):
        raise ValueError(
            f"Unsupported KV element size {elem_bytes} bytes; "
            f"only 1-byte and 2-byte element types are supported."
        )

    # The kernel indexes the pointer tables by layer id without bounds checks.
    if k_data_ptrs.numel() < num_layers or v_data_ptrs.numel() < num_layers:
        raise ValueError(
            f"Pointer tensors must hold at least num_layers={num_layers} entries "
            f"(got k={k_data_ptrs.numel()}, v={v_data_ptrs.numel()})."
        )

    num_tokens = slot_indices.shape[0]

    # The kernel writes this many elements into the pinned buffer, unchecked.
    required_elems = num_heads_to_gather * num_layers * 2 * num_tokens * head_dim
    if pinned_output.numel() < required_elems:
        raise ValueError(
            f"pinned_output holds {pinned_output.numel()} elements but "
            f"{required_elems} are required for the gathered KV data."
        )
    if required_elems == 0:
        # Nothing to copy; also avoids launching a kernel on an empty grid.
        return

    # Block sizes
    BLOCK_TOKENS = 64
    BLOCK_DIM = min(64, triton.next_power_of_2(head_dim))

    # HEAD-FIRST output layout strides (in elements)
    out_head_stride = num_layers * 2 * num_tokens * head_dim
    out_layer_stride = 2 * num_tokens * head_dim
    out_kv_stride = num_tokens * head_dim
    out_token_stride = head_dim

    # Grid: (num_heads_to_gather, num_layers * 2)
    grid = (num_heads_to_gather, num_layers * 2)

    # View as int8/int16 for dtype-agnostic byte copying
    # This ensures Triton pointer arithmetic matches our load/store types
    if elem_bytes == 1:
        output_view = pinned_output.view(torch.int8)
    else:
        output_view = pinned_output.view(torch.int16)

    _gather_kv_all_layers_kernel[grid](
        k_data_ptrs,
        v_data_ptrs,
        slot_indices,
        output_view,
        num_layers=num_layers,
        num_tokens=num_tokens,
        head_dim=head_dim,
        head_start=head_start,
        num_heads_to_gather=num_heads_to_gather,
        src_slot_stride=src_slot_stride,
        src_head_stride=src_head_stride,
        out_head_stride=out_head_stride,
        out_layer_stride=out_layer_stride,
        out_kv_stride=out_kv_stride,
        out_token_stride=out_token_stride,
        BLOCK_TOKENS=BLOCK_TOKENS,
        BLOCK_DIM=BLOCK_DIM,
        ELEM_BYTES=elem_bytes,
    )
num_tokens, head_dim]) + in_head_stride, # = num_layers * 2 * num_tokens * head_dim + in_layer_stride, # = 2 * num_tokens * head_dim + in_kv_stride, # = num_tokens * head_dim + in_token_stride, # = head_dim + # Block sizes + BLOCK_TOKENS: tl.constexpr, + BLOCK_DIM: tl.constexpr, + # Element size: 1 for fp8, 2 for fp16/bf16 + ELEM_BYTES: tl.constexpr = 2, +): + """ + Scatter KV data from pinned CPU to ALL GPU layers in a single kernel launch. + + Reads contiguous from pinned CPU, writes scattered to GPU KV cache. + O(1) extra GPU memory. Dtype-agnostic (copies raw bytes). + + Grid: (num_heads_to_scatter, num_layers * 2) - one program per (head, layer_kv) + - program_id(0): head index (0 to num_heads_to_scatter-1) + - program_id(1): layer_kv index (0 to num_layers*2-1), where even=K, odd=V + """ + head_id = tl.program_id(0) + layer_kv_id = tl.program_id(1) + + # Decode layer and K/V from combined index + layer_id = layer_kv_id // 2 + is_v = layer_kv_id % 2 # 0 = K, 1 = V + + # Load the data pointer for this layer's K or V buffer + # Use int8 or int16 based on element size for dtype-agnostic byte copying + if ELEM_BYTES == 1: + if is_v == 0: + dst_base_ptr = tl.load(k_data_ptrs + layer_id).to( + tl.pointer_type(tl.int8) + ) + else: + dst_base_ptr = tl.load(v_data_ptrs + layer_id).to( + tl.pointer_type(tl.int8) + ) + else: # ELEM_BYTES == 2 + if is_v == 0: + dst_base_ptr = tl.load(k_data_ptrs + layer_id).to( + tl.pointer_type(tl.int16) + ) + else: + dst_base_ptr = tl.load(v_data_ptrs + layer_id).to( + tl.pointer_type(tl.int16) + ) + + # Destination head index (absolute in destination KV cache) + dst_head = head_start + head_id + + # Cast strides to int64 + dst_slot_stride_i64 = dst_slot_stride.to(tl.int64) + dst_head_stride_i64 = dst_head_stride.to(tl.int64) + in_head_stride_i64 = in_head_stride.to(tl.int64) + in_layer_stride_i64 = in_layer_stride.to(tl.int64) + in_kv_stride_i64 = in_kv_stride.to(tl.int64) + in_token_stride_i64 = in_token_stride.to(tl.int64) + 
dst_head_i64 = dst_head.to(tl.int64) + head_id_i64 = head_id.to(tl.int64) + layer_id_i64 = layer_id.to(tl.int64) + is_v_i64 = is_v.to(tl.int64) + + # Base input offset for this (head, layer, kv) in HEAD-FIRST layout + in_base = ( + head_id_i64 * in_head_stride_i64 + + layer_id_i64 * in_layer_stride_i64 + + is_v_i64 * in_kv_stride_i64 + ) + + # Process tokens in blocks + for token_block_start in range(0, num_tokens, BLOCK_TOKENS): + token_offsets = token_block_start + tl.arange(0, BLOCK_TOKENS) + token_mask = token_offsets < num_tokens + token_offsets_i64 = token_offsets.to(tl.int64) + + # Load slot indices for these tokens + slot_ids = tl.load(slot_indices_ptr + token_offsets, mask=token_mask, other=0) + slot_ids_i64 = slot_ids.to(tl.int64) + + # Process head_dim in blocks + for dim_start in range(0, head_dim, BLOCK_DIM): + dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) + dim_mask = dim_offsets < head_dim + dim_offsets_i64 = dim_offsets.to(tl.int64) + + # Compute input addresses in pinned CPU buffer (HEAD-FIRST layout) + in_offsets = ( + in_base + + token_offsets_i64[:, None] * in_token_stride_i64 + + dim_offsets_i64[None, :] + ) + + # Load from pinned CPU buffer (zero-copy read over PCIe) + mask = token_mask[:, None] & dim_mask[None, :] + data = tl.load(input_ptr + in_offsets, mask=mask, other=0.0) + + # Compute destination addresses in GPU KV cache (scattered writes) + dst_offsets = ( + slot_ids_i64[:, None] * dst_slot_stride_i64 + + dst_head_i64 * dst_head_stride_i64 + + dim_offsets_i64[None, :] + ) + + # Store to GPU KV cache + tl.store(dst_base_ptr + dst_offsets, data, mask=mask) + + +def scatter_kv_with_staging_all_layers( + pinned_input: torch.Tensor, + k_data_ptrs: torch.Tensor, # [num_layers] uint64 tensor of K buffer pointers + v_data_ptrs: torch.Tensor, # [num_layers] uint64 tensor of V buffer pointers + slot_indices: torch.Tensor, + head_start: int, + num_heads_to_scatter: int, + num_layers: int, + head_dim: int, + dst_slot_stride: int, # stride 
between slots in dest (num_heads * head_dim) + dst_head_stride: int, # stride between heads in dest (head_dim) + kv_elem_bytes: int = None, # element size of KV cache (1 for fp8, 2 for fp16/bf16) +) -> None: + """ + Scatter KV data to ALL layers using a SINGLE kernel launch. + + O(1) extra GPU memory - reads directly from pinned CPU memory. + + Args: + pinned_input: Pinned CPU buffer in HEAD-FIRST layout + k_data_ptrs: Tensor of uint64 pointers to each layer's K buffer + v_data_ptrs: Tensor of uint64 pointers to each layer's V buffer + slot_indices: Tensor of slot indices to scatter to + head_start: First head index to scatter to + num_heads_to_scatter: Number of heads to scatter + num_layers: Number of layers + head_dim: Dimension of each head + dst_slot_stride: Stride between slots in dest buffers + dst_head_stride: Stride between heads in dest buffers + kv_elem_bytes: Element size of KV cache in bytes. Must match pinned_input.element_size(). + + Input layout: [num_heads_to_scatter, num_layers, 2, num_tokens, head_dim] (HEAD-FIRST) + """ + assert pinned_input.is_pinned(), "Input buffer must be pinned CPU memory" + assert slot_indices.is_cuda, "slot_indices must be on GPU" + assert k_data_ptrs.dtype == torch.uint64, "k_data_ptrs must be uint64" + assert v_data_ptrs.dtype == torch.uint64, "v_data_ptrs must be uint64" + + # Validate element size consistency between KV cache and pinned buffer + pinned_elem_bytes = pinned_input.element_size() + if kv_elem_bytes is not None: + assert pinned_elem_bytes == kv_elem_bytes, ( + f"KV cache element size ({kv_elem_bytes} bytes) does not match " + f"pinned buffer element size ({pinned_elem_bytes} bytes). " + f"The pinned buffer dtype must match the KV cache dtype." 
+ ) + + num_tokens = slot_indices.shape[0] + + # Block sizes + BLOCK_TOKENS = 64 + BLOCK_DIM = min(64, triton.next_power_of_2(head_dim)) + + # HEAD-FIRST input layout: [num_heads_to_scatter, num_layers, 2, num_tokens, head_dim] + # Strides for this layout (in elements): + in_head_stride = num_layers * 2 * num_tokens * head_dim + in_layer_stride = 2 * num_tokens * head_dim + in_kv_stride = num_tokens * head_dim + in_token_stride = head_dim + + # Grid: (num_heads_to_scatter, num_layers * 2) + grid = (num_heads_to_scatter, num_layers * 2) + + # Element size in bytes (1 for fp8, 2 for fp16/bf16) + elem_bytes = pinned_input.element_size() + + # View as int8/int16 for dtype-agnostic byte copying + # This ensures Triton pointer arithmetic matches our load/store types + if elem_bytes == 1: + input_view = pinned_input.view(torch.int8) + else: + input_view = pinned_input.view(torch.int16) + + _scatter_kv_all_layers_from_pinned_kernel[grid]( + k_data_ptrs, + v_data_ptrs, + slot_indices, + input_view, + num_layers=num_layers, + num_tokens=num_tokens, + head_dim=head_dim, + head_start=head_start, + num_heads_to_scatter=num_heads_to_scatter, + dst_slot_stride=dst_slot_stride, + dst_head_stride=dst_head_stride, + in_head_stride=in_head_stride, + in_layer_stride=in_layer_stride, + in_kv_stride=in_kv_stride, + in_token_stride=in_token_stride, + BLOCK_TOKENS=BLOCK_TOKENS, + BLOCK_DIM=BLOCK_DIM, + ELEM_BYTES=elem_bytes, + ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 065ce26baeb8..9e8723b9ec3d 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -668,6 +668,10 @@ class ServerArgs: num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD # FIXME: hack to reduce ITL when decode bs is small disaggregation_decode_polling_interval: int = 1 + # Use pinned CPU buffer for NIXL KV transfers (Triton gather/scatter + single NIXL transfer) + nixl_use_cpu_buffer: bool = False + # Total size of 
pinned CPU buffer for NIXL KV transfers (GB) + nixl_cpu_buffer_size_gb: float = 16.0 # Encode prefill disaggregation encoder_only: bool = False @@ -4981,6 +4985,21 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.disaggregation_decode_polling_interval, help="The interval to poll requests in decode server. Can be set to >1 to reduce the overhead of this.", ) + parser.add_argument( + "--nixl-use-cpu-buffer", + action="store_true", + default=ServerArgs.nixl_use_cpu_buffer, + help="Use pinned CPU buffer for NIXL KV transfers. Uses Triton gather/scatter " + "kernels with a single NIXL transfer, reducing descriptor count from " + "O(tokens * layers) to O(1).", + ) + parser.add_argument( + "--nixl-cpu-buffer-size-gb", + type=float, + default=ServerArgs.nixl_cpu_buffer_size_gb, + help="Total size of pinned CPU buffer for NIXL KV transfers (GB). " + "Should be sized for sum of max concurrent transfers. Default is 16.0.", + ) # Encode prefill disaggregation parser.add_argument( diff --git a/sgl-kernel/benchmark/bench_kv_transfer.py b/sgl-kernel/benchmark/bench_kv_transfer.py new file mode 100644 index 000000000000..14386036f73f --- /dev/null +++ b/sgl-kernel/benchmark/bench_kv_transfer.py @@ -0,0 +1,363 @@ +""" +Benchmark for KV transfer Triton kernels. + +Measures bandwidth achieved when gathering/scattering KV cache data +between GPU and pinned CPU memory. 
+""" + +import argparse + +import torch + +from sglang.srt.layers.attention.triton_ops.kv_transfer import ( + gather_kv_to_pinned_all_layers, + scatter_kv_with_staging_all_layers, +) + + +def create_pointer_tensors(k_buffers, v_buffers): + """Helper to create pointer tensors and get strides.""" + k_data_ptrs = torch.tensor( + [x.data_ptr() for x in k_buffers], dtype=torch.uint64, device="cuda" + ) + v_data_ptrs = torch.tensor( + [x.data_ptr() for x in v_buffers], dtype=torch.uint64, device="cuda" + ) + slot_stride = k_buffers[0].stride(0) + head_stride = k_buffers[0].stride(1) + return k_data_ptrs, v_data_ptrs, slot_stride, head_stride + + +def benchmark_cuda_memcpy( + size_mb: float, + dtype: torch.dtype, + warmup: int = 10, + rep: int = 100, +) -> dict: + """ + Benchmark raw CUDA memcpy in both directions. + Returns bandwidth for D2H (GPU -> CPU) and H2D (CPU -> GPU). + """ + bytes_per_element = 2 if dtype in (torch.float16, torch.bfloat16) else 4 + num_elements = int(size_mb * 1e6 / bytes_per_element) + total_bytes = num_elements * bytes_per_element + + gpu_tensor = torch.randn(num_elements, dtype=dtype, device="cuda") + cpu_tensor = torch.empty(num_elements, dtype=dtype, device="cpu", pin_memory=True) + + # D2H benchmark + for _ in range(warmup): + cpu_tensor.copy_(gpu_tensor, non_blocking=False) + torch.cuda.synchronize() + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(rep)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(rep)] + + for i in range(rep): + start_events[i].record() + cpu_tensor.copy_(gpu_tensor, non_blocking=False) + end_events[i].record() + torch.cuda.synchronize() + + d2h_times_ms = [start_events[i].elapsed_time(end_events[i]) for i in range(rep)] + d2h_avg_time_ms = sum(d2h_times_ms) / len(d2h_times_ms) + d2h_avg_bw = (total_bytes / 1e9) / (d2h_avg_time_ms / 1000) + + # H2D benchmark + for _ in range(warmup): + gpu_tensor.copy_(cpu_tensor, non_blocking=False) + torch.cuda.synchronize() + + for i in 
range(rep): + start_events[i].record() + gpu_tensor.copy_(cpu_tensor, non_blocking=False) + end_events[i].record() + torch.cuda.synchronize() + + h2d_times_ms = [start_events[i].elapsed_time(end_events[i]) for i in range(rep)] + h2d_avg_time_ms = sum(h2d_times_ms) / len(h2d_times_ms) + h2d_avg_bw = (total_bytes / 1e9) / (h2d_avg_time_ms / 1000) + + del gpu_tensor, cpu_tensor + torch.cuda.empty_cache() + + return { + "size_mb": size_mb, + "d2h_avg_time_ms": d2h_avg_time_ms, + "d2h_avg_bandwidth_gbs": d2h_avg_bw, + "h2d_avg_time_ms": h2d_avg_time_ms, + "h2d_avg_bandwidth_gbs": h2d_avg_bw, + } + + +def benchmark_kv_transfer( + num_layers: int, + num_tokens: int, + num_heads: int, + head_dim: int, + total_slots: int, + dtype: torch.dtype = torch.float16, + warmup: int = 5, + rep: int = 20, +) -> dict: + """ + Benchmark gather and scatter kernels with configurable sizes. + """ + bytes_per_element = 2 if dtype in (torch.float16, torch.bfloat16) else 4 + transfer_bytes = num_layers * 2 * num_tokens * num_heads * head_dim * bytes_per_element + + print(f"\n Configuration:") + print(f" num_layers={num_layers}, num_tokens={num_tokens}") + print(f" num_heads={num_heads}, head_dim={head_dim}") + print(f" total_slots={total_slots:,}") + print(f" transfer size = {transfer_bytes / 1e9:.2f} GB ({transfer_bytes / 1e6:.1f} MB)") + + # Create KV buffers + print(f"\n Allocating KV pool...") + k_buffers = [ + torch.randn(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + v_buffers = [ + torch.randn(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + print(f" Pool allocated. 
Free GPU memory: {torch.cuda.mem_get_info()[0] / 1e9:.1f} GB") + + # Create slot indices - random access pattern + slot_indices = torch.randperm(total_slots, device="cuda")[:num_tokens].to(torch.int32) + + # Allocate pinned buffer + output_size = num_layers * 2 * num_tokens * num_heads * head_dim + pinned_buffer = torch.empty(output_size, dtype=dtype, device="cpu", pin_memory=True) + + # Contiguous GPU buffer for baselines + contiguous_gpu = torch.randn(output_size, dtype=dtype, device="cuda") + + results = {} + + # ========================================================================= + # D2H memcpy baseline + # ========================================================================= + print("\n Benchmarking D2H memcpy (GPU -> pinned CPU)...") + for _ in range(warmup): + pinned_buffer.copy_(contiguous_gpu, non_blocking=False) + torch.cuda.synchronize() + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(rep)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(rep)] + for i in range(rep): + start_events[i].record() + pinned_buffer.copy_(contiguous_gpu, non_blocking=False) + end_events[i].record() + torch.cuda.synchronize() + + times_ms = [start_events[i].elapsed_time(end_events[i]) for i in range(rep)] + d2h_time_ms = sum(times_ms) / len(times_ms) + d2h_bw = (transfer_bytes / 1e9) / (d2h_time_ms / 1000) + results["d2h_memcpy"] = {"time_ms": d2h_time_ms, "bandwidth_gbs": d2h_bw} + print(f" D2H memcpy: {d2h_time_ms:.1f} ms, {d2h_bw:.2f} GB/s") + + # ========================================================================= + # H2D memcpy baseline + # ========================================================================= + print("\n Benchmarking H2D memcpy (pinned CPU -> GPU)...") + for _ in range(warmup): + contiguous_gpu.copy_(pinned_buffer, non_blocking=False) + torch.cuda.synchronize() + + for i in range(rep): + start_events[i].record() + contiguous_gpu.copy_(pinned_buffer, non_blocking=False) + end_events[i].record() + 
torch.cuda.synchronize() + + times_ms = [start_events[i].elapsed_time(end_events[i]) for i in range(rep)] + h2d_time_ms = sum(times_ms) / len(times_ms) + h2d_bw = (transfer_bytes / 1e9) / (h2d_time_ms / 1000) + results["h2d_memcpy"] = {"time_ms": h2d_time_ms, "bandwidth_gbs": h2d_bw} + print(f" H2D memcpy: {h2d_time_ms:.1f} ms, {h2d_bw:.2f} GB/s") + + del contiguous_gpu + + # ========================================================================= + # Gather benchmark + # ========================================================================= + print("\n Benchmarking gather (scattered GPU -> pinned CPU)...") + + k_data_ptrs, v_data_ptrs, src_slot_stride, src_head_stride = create_pointer_tensors(k_buffers, v_buffers) + + for _ in range(warmup): + gather_kv_to_pinned_all_layers( + k_data_ptrs=k_data_ptrs, + v_data_ptrs=v_data_ptrs, + slot_indices=slot_indices, + pinned_output=pinned_buffer, + head_start=0, + num_heads_to_gather=num_heads, + num_layers=num_layers, + head_dim=head_dim, + src_slot_stride=src_slot_stride, + src_head_stride=src_head_stride, + ) + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(rep)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(rep)] + for i in range(rep): + start_events[i].record() + gather_kv_to_pinned_all_layers( + k_data_ptrs=k_data_ptrs, + v_data_ptrs=v_data_ptrs, + slot_indices=slot_indices, + pinned_output=pinned_buffer, + head_start=0, + num_heads_to_gather=num_heads, + num_layers=num_layers, + head_dim=head_dim, + src_slot_stride=src_slot_stride, + src_head_stride=src_head_stride, + ) + end_events[i].record() + torch.cuda.synchronize() + + times_ms = [start_events[i].elapsed_time(end_events[i]) for i in range(rep)] + gather_time_ms = sum(times_ms) / len(times_ms) + gather_bw = (transfer_bytes / 1e9) / (gather_time_ms / 1000) + results["gather"] = {"time_ms": gather_time_ms, "bandwidth_gbs": gather_bw} + print(f" Gather: {gather_time_ms:.1f} ms, {gather_bw:.2f} GB/s 
({gather_bw/d2h_bw*100:.0f}% of D2H)") + + # ========================================================================= + # Scatter benchmark + # ========================================================================= + print("\n Benchmarking scatter (pinned CPU -> scattered GPU)...") + + pinned_input = torch.randn(output_size, dtype=dtype, device="cpu", pin_memory=True) + + k_buffers_dst = [ + torch.zeros(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + v_buffers_dst = [ + torch.zeros(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + + k_data_ptrs_dst, v_data_ptrs_dst, dst_slot_stride, dst_head_stride = create_pointer_tensors(k_buffers_dst, v_buffers_dst) + + for _ in range(warmup): + scatter_kv_with_staging_all_layers( + pinned_input=pinned_input, + k_data_ptrs=k_data_ptrs_dst, + v_data_ptrs=v_data_ptrs_dst, + slot_indices=slot_indices, + head_start=0, + num_heads_to_scatter=num_heads, + num_layers=num_layers, + head_dim=head_dim, + dst_slot_stride=dst_slot_stride, + dst_head_stride=dst_head_stride, + ) + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(rep)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(rep)] + for i in range(rep): + start_events[i].record() + scatter_kv_with_staging_all_layers( + pinned_input=pinned_input, + k_data_ptrs=k_data_ptrs_dst, + v_data_ptrs=v_data_ptrs_dst, + slot_indices=slot_indices, + head_start=0, + num_heads_to_scatter=num_heads, + num_layers=num_layers, + head_dim=head_dim, + dst_slot_stride=dst_slot_stride, + dst_head_stride=dst_head_stride, + ) + end_events[i].record() + torch.cuda.synchronize() + + times_ms = [start_events[i].elapsed_time(end_events[i]) for i in range(rep)] + scatter_time_ms = sum(times_ms) / len(times_ms) + scatter_bw = (transfer_bytes / 1e9) / (scatter_time_ms / 1000) + results["scatter"] = {"time_ms": scatter_time_ms, "bandwidth_gbs": scatter_bw} + print(f" Scatter: 
{scatter_time_ms:.1f} ms, {scatter_bw:.2f} GB/s ({scatter_bw/h2d_bw*100:.0f}% of H2D)") + + # Clean up + del k_buffers, v_buffers, k_buffers_dst, v_buffers_dst + del pinned_buffer, pinned_input, slot_indices + torch.cuda.empty_cache() + + return results + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark KV transfer kernels") + parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations") + parser.add_argument("--rep", type=int, default=20, help="Benchmark repetitions") + parser.add_argument("--num-layers", type=int, default=92, help="Number of layers") + parser.add_argument("--num-tokens", type=int, default=32768, help="Number of tokens") + parser.add_argument("--num-heads", type=int, default=8, help="Number of heads") + parser.add_argument("--head-dim", type=int, default=128, help="Head dimension") + parser.add_argument("--total-slots", type=int, default=None, help="Total slots (default: 4x num_tokens)") + args = parser.parse_args() + + print("=" * 80) + print(" KV Transfer Kernel Benchmark") + print(" Single-kernel gather/scatter (O(1) GPU memory overhead)") + print("=" * 80) + + # Get GPU info + device = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device) + print(f"\nGPU: {props.name}") + + total_slots = args.total_slots or (args.num_tokens * 4) + transfer_size_gb = args.num_layers * 2 * args.num_tokens * args.num_heads * args.head_dim * 2 / 1e9 + + print("\n" + "=" * 80) + print(" KV Transfer Benchmark") + print(f" {args.num_layers} layers, {args.num_tokens} tokens, {args.num_heads} heads, {args.head_dim} head_dim") + print(f" Transfer size: {transfer_size_gb:.2f} GB") + print("=" * 80) + + results = benchmark_kv_transfer( + num_layers=args.num_layers, + num_tokens=args.num_tokens, + num_heads=args.num_heads, + head_dim=args.head_dim, + total_slots=total_slots, + warmup=args.warmup, + rep=args.rep, + ) + + # Summary + print("\n" + "=" * 80) + print(f" Summary ({transfer_size_gb:.2f} GB 
transfers)") + print("=" * 80) + + d2h_bw = results["d2h_memcpy"]["bandwidth_gbs"] + h2d_bw = results["h2d_memcpy"]["bandwidth_gbs"] + + print(f"\n PCIe Baselines (contiguous transfers):") + print(f" {'Method':<25} {'Time (ms)':<12} {'BW (GB/s)':<12}") + print("-" * 50) + print(f" {'D2H memcpy (GPU->CPU)':<25} {results['d2h_memcpy']['time_ms']:<12.1f} {d2h_bw:<12.2f}") + print(f" {'H2D memcpy (CPU->GPU)':<25} {results['h2d_memcpy']['time_ms']:<12.1f} {h2d_bw:<12.2f}") + + print(f"\n Kernel Performance:") + print(f" {'Method':<25} {'Time (ms)':<12} {'BW (GB/s)':<12} {'Efficiency':<12}") + print("-" * 65) + gather_eff = results['gather']['bandwidth_gbs'] / d2h_bw * 100 + scatter_eff = results['scatter']['bandwidth_gbs'] / h2d_bw * 100 + print(f" {'Gather (GPU->CPU)':<25} {results['gather']['time_ms']:<12.1f} {results['gather']['bandwidth_gbs']:<12.2f} {gather_eff:.0f}% of D2H") + print(f" {'Scatter (CPU->GPU)':<25} {results['scatter']['time_ms']:<12.1f} {results['scatter']['bandwidth_gbs']:<12.2f} {scatter_eff:.0f}% of H2D") + + print(f"\n Key features:") + print(f" - Single kernel launch for all {args.num_layers} layers") + print(f" - O(1) extra GPU memory (just {args.num_layers * 16 / 1024:.1f} KB for pointer tensors)") + print(f" - No staging buffers needed") + + +if __name__ == "__main__": + main() diff --git a/sgl-kernel/tests/test_kv_transfer.py b/sgl-kernel/tests/test_kv_transfer.py new file mode 100644 index 000000000000..041cd5efddf5 --- /dev/null +++ b/sgl-kernel/tests/test_kv_transfer.py @@ -0,0 +1,1025 @@ +""" +Tests for KV transfer Triton kernels. 
+ +Tests the gather and scatter kernels for KV cache transfers: +- gather_kv_to_pinned_all_layers: GPU -> pinned CPU (device to host) +- scatter_kv_with_staging_all_layers: pinned CPU -> GPU (host to device) +""" + +import pytest +import torch + +from sglang.srt.layers.attention.triton_ops.kv_transfer import ( + gather_kv_to_pinned_all_layers, + scatter_kv_with_staging_all_layers, +) + + +def reference_gather_kv( + k_buffers: list[torch.Tensor], + v_buffers: list[torch.Tensor], + slot_indices: torch.Tensor, + head_start: int, + num_heads_to_gather: int, +) -> torch.Tensor: + """ + Reference implementation of KV gather using PyTorch operations. + + Returns tensor of shape [num_heads_to_gather, num_layers, 2, num_tokens, head_dim] + This is the HEAD-FIRST layout for easy head slicing in mixed-TP transfers. + """ + num_layers = len(k_buffers) + num_tokens = slot_indices.shape[0] + head_dim = k_buffers[0].shape[2] + dtype = k_buffers[0].dtype + + output = torch.zeros( + (num_heads_to_gather, num_layers, 2, num_tokens, head_dim), + dtype=dtype, + device=k_buffers[0].device, + ) + + head_end = head_start + num_heads_to_gather + + for layer_idx in range(num_layers): + k_data = k_buffers[layer_idx][slot_indices, head_start:head_end, :] + v_data = v_buffers[layer_idx][slot_indices, head_start:head_end, :] + + for h in range(num_heads_to_gather): + output[h, layer_idx, 0] = k_data[:, h, :] + output[h, layer_idx, 1] = v_data[:, h, :] + + return output + + +def reference_scatter_kv( + pinned_input: torch.Tensor, + slot_indices: torch.Tensor, + num_layers: int, + num_heads_to_scatter: int, + head_dim: int, + total_slots: int, + num_heads: int, + head_start: int, + dtype: torch.dtype, +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """ + Reference implementation of KV scatter using PyTorch operations. + + Input is in HEAD-FIRST layout: [num_heads_to_scatter, num_layers, 2, num_tokens, head_dim] + Returns (k_buffers, v_buffers) with data scattered to the specified slots. 
+ """ + num_tokens = slot_indices.shape[0] + + k_buffers = [ + torch.zeros(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + v_buffers = [ + torch.zeros(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + + input_shaped = pinned_input.view(num_heads_to_scatter, num_layers, 2, num_tokens, head_dim) + + for layer_idx in range(num_layers): + for h in range(num_heads_to_scatter): + k_buffers[layer_idx][slot_indices, head_start + h, :] = input_shaped[h, layer_idx, 0].cuda() + v_buffers[layer_idx][slot_indices, head_start + h, :] = input_shaped[h, layer_idx, 1].cuda() + + return k_buffers, v_buffers + + +def create_pointer_tensors(k_buffers, v_buffers): + """Helper to create pointer tensors and get strides.""" + k_data_ptrs = torch.tensor( + [x.data_ptr() for x in k_buffers], dtype=torch.uint64, device="cuda" + ) + v_data_ptrs = torch.tensor( + [x.data_ptr() for x in v_buffers], dtype=torch.uint64, device="cuda" + ) + slot_stride = k_buffers[0].stride(0) + head_stride = k_buffers[0].stride(1) + return k_data_ptrs, v_data_ptrs, slot_stride, head_stride + + +# ============================================================================= +# Gather Tests (Device -> Host) +# ============================================================================= + + +@pytest.mark.parametrize("num_layers", [1, 4, 32]) +@pytest.mark.parametrize("num_tokens", [1, 64, 512]) +@pytest.mark.parametrize("num_heads", [8, 32]) +@pytest.mark.parametrize("head_dim", [64, 128]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_gather_kv_full_heads(num_layers, num_tokens, num_heads, head_dim, dtype): + """Test gathering all heads (no slicing).""" + total_slots = 1024 + + k_buffers = [ + torch.randn(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + v_buffers = [ + torch.randn(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") 
+ for _ in range(num_layers) + ] + + slot_indices = torch.randperm(total_slots, device="cuda")[:num_tokens].to(torch.int32) + + k_data_ptrs, v_data_ptrs, src_slot_stride, src_head_stride = create_pointer_tensors(k_buffers, v_buffers) + + output_size = num_layers * 2 * num_tokens * num_heads * head_dim + pinned_output = torch.empty(output_size, dtype=dtype, device="cpu", pin_memory=True) + + gather_kv_to_pinned_all_layers( + k_data_ptrs=k_data_ptrs, + v_data_ptrs=v_data_ptrs, + slot_indices=slot_indices, + pinned_output=pinned_output, + head_start=0, + num_heads_to_gather=num_heads, + num_layers=num_layers, + head_dim=head_dim, + src_slot_stride=src_slot_stride, + src_head_stride=src_head_stride, + ) + + expected = reference_gather_kv( + k_buffers, v_buffers, slot_indices.long(), + head_start=0, num_heads_to_gather=num_heads + ) + + actual = pinned_output.view(num_heads, num_layers, 2, num_tokens, head_dim) + torch.testing.assert_close(actual, expected.cpu(), rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("num_heads", [8, 32]) +@pytest.mark.parametrize("head_start,num_heads_to_gather", [(0, 4), (4, 4), (0, 8), (8, 8)]) +def test_gather_kv_head_slicing(num_heads, head_start, num_heads_to_gather): + """Test gathering a subset of heads (for mixed-TP).""" + if head_start + num_heads_to_gather > num_heads: + pytest.skip("head slice exceeds num_heads") + + num_layers = 4 + num_tokens = 128 + head_dim = 128 + total_slots = 512 + dtype = torch.float16 + + k_buffers = [ + torch.randn(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + v_buffers = [ + torch.randn(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + + slot_indices = torch.randperm(total_slots, device="cuda")[:num_tokens].to(torch.int32) + + k_data_ptrs, v_data_ptrs, src_slot_stride, src_head_stride = create_pointer_tensors(k_buffers, v_buffers) + + output_size = num_layers * 2 * num_tokens * num_heads_to_gather * 
head_dim + pinned_output = torch.empty(output_size, dtype=dtype, device="cpu", pin_memory=True) + + gather_kv_to_pinned_all_layers( + k_data_ptrs=k_data_ptrs, + v_data_ptrs=v_data_ptrs, + slot_indices=slot_indices, + pinned_output=pinned_output, + head_start=head_start, + num_heads_to_gather=num_heads_to_gather, + num_layers=num_layers, + head_dim=head_dim, + src_slot_stride=src_slot_stride, + src_head_stride=src_head_stride, + ) + + expected = reference_gather_kv( + k_buffers, v_buffers, slot_indices.long(), + head_start=head_start, num_heads_to_gather=num_heads_to_gather + ) + + actual = pinned_output.view(num_heads_to_gather, num_layers, 2, num_tokens, head_dim) + torch.testing.assert_close(actual, expected.cpu(), rtol=1e-3, atol=1e-3) + + +def test_gather_kv_contiguous_indices(): + """Test with contiguous slot indices (best case for memory access).""" + num_layers = 4 + num_tokens = 256 + num_heads = 8 + head_dim = 128 + total_slots = 1024 + dtype = torch.float16 + + k_buffers = [ + torch.randn(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + v_buffers = [ + torch.randn(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + + start_slot = 100 + slot_indices = torch.arange( + start_slot, start_slot + num_tokens, device="cuda", dtype=torch.int32 + ) + + k_data_ptrs, v_data_ptrs, src_slot_stride, src_head_stride = create_pointer_tensors(k_buffers, v_buffers) + + output_size = num_layers * 2 * num_tokens * num_heads * head_dim + pinned_output = torch.empty(output_size, dtype=dtype, device="cpu", pin_memory=True) + + gather_kv_to_pinned_all_layers( + k_data_ptrs=k_data_ptrs, + v_data_ptrs=v_data_ptrs, + slot_indices=slot_indices, + pinned_output=pinned_output, + head_start=0, + num_heads_to_gather=num_heads, + num_layers=num_layers, + head_dim=head_dim, + src_slot_stride=src_slot_stride, + src_head_stride=src_head_stride, + ) + + expected = reference_gather_kv( + k_buffers, 
v_buffers, slot_indices.long(), + head_start=0, num_heads_to_gather=num_heads + ) + + actual = pinned_output.view(num_heads, num_layers, 2, num_tokens, head_dim) + torch.testing.assert_close(actual, expected.cpu(), rtol=1e-3, atol=1e-3) + + +def test_gather_kv_non_pinned_raises(): + """Test that non-pinned output raises an error.""" + k_buffers = [torch.randn(64, 4, 32, dtype=torch.float16, device="cuda")] + v_buffers = [torch.randn(64, 4, 32, dtype=torch.float16, device="cuda")] + slot_indices = torch.tensor([0, 1, 2, 3], dtype=torch.int32, device="cuda") + + k_data_ptrs, v_data_ptrs, src_slot_stride, src_head_stride = create_pointer_tensors(k_buffers, v_buffers) + + non_pinned_output = torch.empty(1 * 2 * 4 * 4 * 32, dtype=torch.float16, device="cpu") + + with pytest.raises(AssertionError, match="pinned"): + gather_kv_to_pinned_all_layers( + k_data_ptrs=k_data_ptrs, + v_data_ptrs=v_data_ptrs, + slot_indices=slot_indices, + pinned_output=non_pinned_output, + head_start=0, + num_heads_to_gather=4, + num_layers=1, + head_dim=32, + src_slot_stride=src_slot_stride, + src_head_stride=src_head_stride, + ) + + +@pytest.mark.parametrize("total_slots", [100_000, 500_000]) +@pytest.mark.parametrize("num_tokens", [128, 1024, 4096]) +def test_gather_kv_large_pool_sparse_access(total_slots, num_tokens): + """Test sparse access pattern in a large KV cache pool.""" + if num_tokens > total_slots: + pytest.skip("num_tokens exceeds total_slots") + + num_layers = 4 + num_heads = 8 + head_dim = 128 + dtype = torch.float16 + + k_buffers = [ + torch.randn(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + v_buffers = [ + torch.randn(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + + slot_indices = torch.randperm(total_slots, device="cuda")[:num_tokens].to(torch.int32) + + k_data_ptrs, v_data_ptrs, src_slot_stride, src_head_stride = create_pointer_tensors(k_buffers, v_buffers) + + output_size = 
num_layers * 2 * num_tokens * num_heads * head_dim + pinned_output = torch.empty(output_size, dtype=dtype, device="cpu", pin_memory=True) + + gather_kv_to_pinned_all_layers( + k_data_ptrs=k_data_ptrs, + v_data_ptrs=v_data_ptrs, + slot_indices=slot_indices, + pinned_output=pinned_output, + head_start=0, + num_heads_to_gather=num_heads, + num_layers=num_layers, + head_dim=head_dim, + src_slot_stride=src_slot_stride, + src_head_stride=src_head_stride, + ) + + expected = reference_gather_kv( + k_buffers, v_buffers, slot_indices.long(), + head_start=0, num_heads_to_gather=num_heads + ) + + actual = pinned_output.view(num_heads, num_layers, 2, num_tokens, head_dim) + torch.testing.assert_close(actual, expected.cpu(), rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("num_tokens", [8192, 16384]) +def test_gather_kv_long_sequence(num_tokens): + """Test with long sequences (many tokens to gather).""" + num_layers = 4 + num_heads = 8 + head_dim = 128 + total_slots = num_tokens + 10_000 + dtype = torch.float16 + + k_buffers = [ + torch.randn(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + v_buffers = [ + torch.randn(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + + slot_indices = torch.randperm(total_slots, device="cuda")[:num_tokens].to(torch.int32) + + k_data_ptrs, v_data_ptrs, src_slot_stride, src_head_stride = create_pointer_tensors(k_buffers, v_buffers) + + output_size = num_layers * 2 * num_tokens * num_heads * head_dim + pinned_output = torch.empty(output_size, dtype=dtype, device="cpu", pin_memory=True) + + gather_kv_to_pinned_all_layers( + k_data_ptrs=k_data_ptrs, + v_data_ptrs=v_data_ptrs, + slot_indices=slot_indices, + pinned_output=pinned_output, + head_start=0, + num_heads_to_gather=num_heads, + num_layers=num_layers, + head_dim=head_dim, + src_slot_stride=src_slot_stride, + src_head_stride=src_head_stride, + ) + + expected = reference_gather_kv( + k_buffers, 
v_buffers, slot_indices.long(), + head_start=0, num_heads_to_gather=num_heads + ) + + actual = pinned_output.view(num_heads, num_layers, 2, num_tokens, head_dim) + torch.testing.assert_close(actual, expected.cpu(), rtol=1e-3, atol=1e-3) + + +# ============================================================================= +# Scatter Tests (Host -> Device) +# ============================================================================= + + +@pytest.mark.parametrize("num_layers", [1, 4, 32]) +@pytest.mark.parametrize("num_tokens", [1, 64, 512]) +@pytest.mark.parametrize("num_heads", [8, 32]) +@pytest.mark.parametrize("head_dim", [64, 128]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_scatter_kv_full_heads(num_layers, num_tokens, num_heads, head_dim, dtype): + """Test scattering all heads (no slicing).""" + total_slots = 1024 + + input_size = num_layers * 2 * num_tokens * num_heads * head_dim + pinned_input = torch.randn(input_size, dtype=dtype, device="cpu", pin_memory=True) + + slot_indices = torch.randperm(total_slots, device="cuda")[:num_tokens].to(torch.int32) + + k_buffers = [ + torch.zeros(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + v_buffers = [ + torch.zeros(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + + k_data_ptrs, v_data_ptrs, dst_slot_stride, dst_head_stride = create_pointer_tensors(k_buffers, v_buffers) + + scatter_kv_with_staging_all_layers( + pinned_input=pinned_input, + k_data_ptrs=k_data_ptrs, + v_data_ptrs=v_data_ptrs, + slot_indices=slot_indices, + head_start=0, + num_heads_to_scatter=num_heads, + num_layers=num_layers, + head_dim=head_dim, + dst_slot_stride=dst_slot_stride, + dst_head_stride=dst_head_stride, + ) + + expected_k, expected_v = reference_scatter_kv( + pinned_input, slot_indices.long(), + num_layers=num_layers, + num_heads_to_scatter=num_heads, + head_dim=head_dim, + total_slots=total_slots, + 
num_heads=num_heads, + head_start=0, + dtype=dtype, + ) + + for layer_idx in range(num_layers): + torch.testing.assert_close(k_buffers[layer_idx], expected_k[layer_idx], rtol=1e-3, atol=1e-3) + torch.testing.assert_close(v_buffers[layer_idx], expected_v[layer_idx], rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("head_start,num_heads_to_scatter", [ + (0, 4), + (4, 4), + (2, 4), + (0, 8), +]) +def test_scatter_kv_head_slicing(head_start, num_heads_to_scatter): + """Test scattering a subset of heads (for mixed-TP).""" + num_layers = 4 + num_tokens = 256 + num_heads = 8 + head_dim = 128 + total_slots = 2048 + dtype = torch.float16 + + input_size = num_layers * 2 * num_tokens * num_heads_to_scatter * head_dim + pinned_input = torch.randn(input_size, dtype=dtype, device="cpu", pin_memory=True) + + slot_indices = torch.randperm(total_slots, device="cuda")[:num_tokens].to(torch.int32) + + k_buffers = [ + torch.zeros(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + v_buffers = [ + torch.zeros(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + + k_data_ptrs, v_data_ptrs, dst_slot_stride, dst_head_stride = create_pointer_tensors(k_buffers, v_buffers) + + scatter_kv_with_staging_all_layers( + pinned_input=pinned_input, + k_data_ptrs=k_data_ptrs, + v_data_ptrs=v_data_ptrs, + slot_indices=slot_indices, + head_start=head_start, + num_heads_to_scatter=num_heads_to_scatter, + num_layers=num_layers, + head_dim=head_dim, + dst_slot_stride=dst_slot_stride, + dst_head_stride=dst_head_stride, + ) + + expected_k, expected_v = reference_scatter_kv( + pinned_input, slot_indices.long(), + num_layers=num_layers, + num_heads_to_scatter=num_heads_to_scatter, + head_dim=head_dim, + total_slots=total_slots, + num_heads=num_heads, + head_start=head_start, + dtype=dtype, + ) + + for layer_idx in range(num_layers): + torch.testing.assert_close(k_buffers[layer_idx], expected_k[layer_idx], rtol=1e-3, 
atol=1e-3) + torch.testing.assert_close(v_buffers[layer_idx], expected_v[layer_idx], rtol=1e-3, atol=1e-3) + + +def test_scatter_kv_roundtrip(): + """ + Test that gather followed by scatter is identity (roundtrip test). + + Data gathered from GPU to pinned CPU should be correctly scattered back to GPU. + """ + num_layers = 32 + num_tokens = 1024 + num_heads = 8 + head_dim = 128 + total_slots = 10_000 + dtype = torch.float16 + + # Source KV buffers + k_buffers_src = [ + torch.randn(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + v_buffers_src = [ + torch.randn(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + + slot_indices = torch.randperm(total_slots, device="cuda")[:num_tokens].to(torch.int32) + + k_data_ptrs_src, v_data_ptrs_src, src_slot_stride, src_head_stride = create_pointer_tensors(k_buffers_src, v_buffers_src) + + # Gather + output_size = num_layers * 2 * num_tokens * num_heads * head_dim + pinned_buffer = torch.empty(output_size, dtype=dtype, device="cpu", pin_memory=True) + + gather_kv_to_pinned_all_layers( + k_data_ptrs=k_data_ptrs_src, + v_data_ptrs=v_data_ptrs_src, + slot_indices=slot_indices, + pinned_output=pinned_buffer, + head_start=0, + num_heads_to_gather=num_heads, + num_layers=num_layers, + head_dim=head_dim, + src_slot_stride=src_slot_stride, + src_head_stride=src_head_stride, + ) + + # Destination KV buffers + k_buffers_dst = [ + torch.zeros(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + v_buffers_dst = [ + torch.zeros(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + + k_data_ptrs_dst, v_data_ptrs_dst, dst_slot_stride, dst_head_stride = create_pointer_tensors(k_buffers_dst, v_buffers_dst) + + # Scatter + scatter_kv_with_staging_all_layers( + pinned_input=pinned_buffer, + k_data_ptrs=k_data_ptrs_dst, + v_data_ptrs=v_data_ptrs_dst, + 
slot_indices=slot_indices, + head_start=0, + num_heads_to_scatter=num_heads, + num_layers=num_layers, + head_dim=head_dim, + dst_slot_stride=dst_slot_stride, + dst_head_stride=dst_head_stride, + ) + + # Verify roundtrip + for layer_idx in range(num_layers): + src_k = k_buffers_src[layer_idx][slot_indices.long()] + dst_k = k_buffers_dst[layer_idx][slot_indices.long()] + torch.testing.assert_close(dst_k, src_k, rtol=1e-3, atol=1e-3) + + src_v = v_buffers_src[layer_idx][slot_indices.long()] + dst_v = v_buffers_dst[layer_idx][slot_indices.long()] + torch.testing.assert_close(dst_v, src_v, rtol=1e-3, atol=1e-3) + + +def test_scatter_kv_large_pool(): + """Test scatter with large pool (high fragmentation).""" + torch.cuda.empty_cache() + + num_layers = 32 + num_tokens = 2048 + num_heads = 8 + head_dim = 128 + total_slots = 100_000 + dtype = torch.float16 + + input_size = num_layers * 2 * num_tokens * num_heads * head_dim + pinned_input = torch.randn(input_size, dtype=dtype, device="cpu", pin_memory=True) + + slot_indices = torch.randperm(total_slots, device="cuda")[:num_tokens].to(torch.int32) + + k_buffers = [ + torch.zeros(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + v_buffers = [ + torch.zeros(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + + k_data_ptrs, v_data_ptrs, dst_slot_stride, dst_head_stride = create_pointer_tensors(k_buffers, v_buffers) + + scatter_kv_with_staging_all_layers( + pinned_input=pinned_input, + k_data_ptrs=k_data_ptrs, + v_data_ptrs=v_data_ptrs, + slot_indices=slot_indices, + head_start=0, + num_heads_to_scatter=num_heads, + num_layers=num_layers, + head_dim=head_dim, + dst_slot_stride=dst_slot_stride, + dst_head_stride=dst_head_stride, + ) + + expected_k, expected_v = reference_scatter_kv( + pinned_input, slot_indices.long(), + num_layers=num_layers, + num_heads_to_scatter=num_heads, + head_dim=head_dim, + total_slots=total_slots, + 
num_heads=num_heads, + head_start=0, + dtype=dtype, + ) + + for layer_idx in range(num_layers): + actual_k = k_buffers[layer_idx][slot_indices.long()] + expected_k_at_slots = expected_k[layer_idx][slot_indices.long()] + torch.testing.assert_close(actual_k, expected_k_at_slots, rtol=1e-3, atol=1e-3) + + actual_v = v_buffers[layer_idx][slot_indices.long()] + expected_v_at_slots = expected_v[layer_idx][slot_indices.long()] + torch.testing.assert_close(actual_v, expected_v_at_slots, rtol=1e-3, atol=1e-3) + + +# ============================================================================= +# FP8 Dtype Tests +# ============================================================================= + + +def has_fp8_support(): + """Check if the current GPU supports FP8.""" + if not torch.cuda.is_available(): + return False + capability = torch.cuda.get_device_capability() + # FP8 requires SM89+ (Ada Lovelace / Hopper) + return capability[0] >= 9 or (capability[0] == 8 and capability[1] >= 9) + + +@pytest.mark.skipif(not has_fp8_support(), reason="FP8 requires SM89+ GPU") +@pytest.mark.parametrize("num_layers", [1, 4]) +@pytest.mark.parametrize("num_tokens", [64, 256]) +@pytest.mark.parametrize("num_heads", [8]) +@pytest.mark.parametrize("head_dim", [64, 128]) +def test_gather_kv_fp8(num_layers, num_tokens, num_heads, head_dim): + """Test gathering with FP8 dtype (e4m3fn).""" + total_slots = 1024 + dtype = torch.float8_e4m3fn + + # Create FP8 KV buffers + k_buffers = [ + torch.randn(total_slots, num_heads, head_dim, dtype=torch.float16, device="cuda").to(dtype) + for _ in range(num_layers) + ] + v_buffers = [ + torch.randn(total_slots, num_heads, head_dim, dtype=torch.float16, device="cuda").to(dtype) + for _ in range(num_layers) + ] + + slot_indices = torch.randperm(total_slots, device="cuda")[:num_tokens].to(torch.int32) + + k_data_ptrs, v_data_ptrs, src_slot_stride, src_head_stride = create_pointer_tensors(k_buffers, v_buffers) + + output_size = num_layers * 2 * num_tokens * 
num_heads * head_dim + pinned_output = torch.empty(output_size, dtype=dtype, device="cpu", pin_memory=True) + + # Pass kv_elem_bytes to trigger validation + kv_elem_bytes = k_buffers[0].element_size() + assert kv_elem_bytes == 1, "FP8 should be 1 byte" + + gather_kv_to_pinned_all_layers( + k_data_ptrs=k_data_ptrs, + v_data_ptrs=v_data_ptrs, + slot_indices=slot_indices, + pinned_output=pinned_output, + head_start=0, + num_heads_to_gather=num_heads, + num_layers=num_layers, + head_dim=head_dim, + src_slot_stride=src_slot_stride, + src_head_stride=src_head_stride, + kv_elem_bytes=kv_elem_bytes, + ) + + # Verify by comparing with reference (convert to float16 for comparison) + expected = reference_gather_kv( + k_buffers, v_buffers, slot_indices.long(), + head_start=0, num_heads_to_gather=num_heads + ) + + actual = pinned_output.view(num_heads, num_layers, 2, num_tokens, head_dim) + # Compare as raw bytes since FP8 doesn't support direct comparison + torch.testing.assert_close( + actual.view(torch.int8), + expected.cpu().view(torch.int8), + ) + + +@pytest.mark.skipif(not has_fp8_support(), reason="FP8 requires SM89+ GPU") +@pytest.mark.parametrize("num_layers", [1, 4]) +@pytest.mark.parametrize("num_tokens", [64, 256]) +@pytest.mark.parametrize("num_heads", [8]) +@pytest.mark.parametrize("head_dim", [64, 128]) +def test_scatter_kv_fp8(num_layers, num_tokens, num_heads, head_dim): + """Test scattering with FP8 dtype (e4m3fn).""" + total_slots = 1024 + dtype = torch.float8_e4m3fn + + input_size = num_layers * 2 * num_tokens * num_heads * head_dim + # Create FP8 pinned input + pinned_input = torch.randn(input_size, dtype=torch.float16, device="cpu").to(dtype).pin_memory() + + slot_indices = torch.randperm(total_slots, device="cuda")[:num_tokens].to(torch.int32) + + k_buffers = [ + torch.zeros(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + v_buffers = [ + torch.zeros(total_slots, num_heads, head_dim, dtype=dtype, 
device="cuda") + for _ in range(num_layers) + ] + + k_data_ptrs, v_data_ptrs, dst_slot_stride, dst_head_stride = create_pointer_tensors(k_buffers, v_buffers) + + kv_elem_bytes = k_buffers[0].element_size() + assert kv_elem_bytes == 1, "FP8 should be 1 byte" + + scatter_kv_with_staging_all_layers( + pinned_input=pinned_input, + k_data_ptrs=k_data_ptrs, + v_data_ptrs=v_data_ptrs, + slot_indices=slot_indices, + head_start=0, + num_heads_to_scatter=num_heads, + num_layers=num_layers, + head_dim=head_dim, + dst_slot_stride=dst_slot_stride, + dst_head_stride=dst_head_stride, + kv_elem_bytes=kv_elem_bytes, + ) + + # Verify by reference implementation (using float16 internally) + expected_k, expected_v = reference_scatter_kv( + pinned_input.to(torch.float16), slot_indices.long(), + num_layers=num_layers, + num_heads_to_scatter=num_heads, + head_dim=head_dim, + total_slots=total_slots, + num_heads=num_heads, + head_start=0, + dtype=torch.float16, + ) + + for layer_idx in range(num_layers): + # Compare as raw bytes + actual_k = k_buffers[layer_idx].view(torch.int8) + expected_k_bytes = expected_k[layer_idx].to(dtype).view(torch.int8) + torch.testing.assert_close(actual_k, expected_k_bytes) + + actual_v = v_buffers[layer_idx].view(torch.int8) + expected_v_bytes = expected_v[layer_idx].to(dtype).view(torch.int8) + torch.testing.assert_close(actual_v, expected_v_bytes) + + +@pytest.mark.skipif(not has_fp8_support(), reason="FP8 requires SM89+ GPU") +def test_gather_scatter_fp8_roundtrip(): + """Test that FP8 gather followed by scatter is identity (roundtrip).""" + num_layers = 4 + num_tokens = 512 + num_heads = 8 + head_dim = 128 + total_slots = 2048 + dtype = torch.float8_e4m3fn + + # Source KV buffers in FP8 + k_buffers_src = [ + torch.randn(total_slots, num_heads, head_dim, dtype=torch.float16, device="cuda").to(dtype) + for _ in range(num_layers) + ] + v_buffers_src = [ + torch.randn(total_slots, num_heads, head_dim, dtype=torch.float16, device="cuda").to(dtype) + for _ in 
range(num_layers) + ] + + slot_indices = torch.randperm(total_slots, device="cuda")[:num_tokens].to(torch.int32) + + k_data_ptrs_src, v_data_ptrs_src, src_slot_stride, src_head_stride = create_pointer_tensors( + k_buffers_src, v_buffers_src + ) + + # Gather to FP8 pinned buffer + output_size = num_layers * 2 * num_tokens * num_heads * head_dim + pinned_buffer = torch.empty(output_size, dtype=dtype, device="cpu", pin_memory=True) + + kv_elem_bytes = k_buffers_src[0].element_size() + + gather_kv_to_pinned_all_layers( + k_data_ptrs=k_data_ptrs_src, + v_data_ptrs=v_data_ptrs_src, + slot_indices=slot_indices, + pinned_output=pinned_buffer, + head_start=0, + num_heads_to_gather=num_heads, + num_layers=num_layers, + head_dim=head_dim, + src_slot_stride=src_slot_stride, + src_head_stride=src_head_stride, + kv_elem_bytes=kv_elem_bytes, + ) + + # Destination KV buffers in FP8 + k_buffers_dst = [ + torch.zeros(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + v_buffers_dst = [ + torch.zeros(total_slots, num_heads, head_dim, dtype=dtype, device="cuda") + for _ in range(num_layers) + ] + + k_data_ptrs_dst, v_data_ptrs_dst, dst_slot_stride, dst_head_stride = create_pointer_tensors( + k_buffers_dst, v_buffers_dst + ) + + # Scatter from FP8 pinned buffer + scatter_kv_with_staging_all_layers( + pinned_input=pinned_buffer, + k_data_ptrs=k_data_ptrs_dst, + v_data_ptrs=v_data_ptrs_dst, + slot_indices=slot_indices, + head_start=0, + num_heads_to_scatter=num_heads, + num_layers=num_layers, + head_dim=head_dim, + dst_slot_stride=dst_slot_stride, + dst_head_stride=dst_head_stride, + kv_elem_bytes=kv_elem_bytes, + ) + + # Verify roundtrip (compare as bytes) + for layer_idx in range(num_layers): + src_k = k_buffers_src[layer_idx][slot_indices.long()].view(torch.int8) + dst_k = k_buffers_dst[layer_idx][slot_indices.long()].view(torch.int8) + torch.testing.assert_close(dst_k, src_k) + + src_v = 
v_buffers_src[layer_idx][slot_indices.long()].view(torch.int8) + dst_v = v_buffers_dst[layer_idx][slot_indices.long()].view(torch.int8) + torch.testing.assert_close(dst_v, src_v) + + +# ============================================================================= +# Dtype Mismatch Tests (Validation) +# ============================================================================= + + +def test_gather_kv_dtype_mismatch_raises(): + """ + Test that mismatched KV cache and pinned buffer dtypes raise an assertion error. + + This catches the bug where FP8 KV cache (1 byte) was used with bfloat16 pinned buffer (2 bytes), + causing incorrect pointer arithmetic and memory corruption. + """ + num_layers = 1 + num_tokens = 64 + num_heads = 4 + head_dim = 64 + total_slots = 256 + + # KV cache in float16 (2 bytes) + k_buffers = [ + torch.randn(total_slots, num_heads, head_dim, dtype=torch.float16, device="cuda") + for _ in range(num_layers) + ] + v_buffers = [ + torch.randn(total_slots, num_heads, head_dim, dtype=torch.float16, device="cuda") + for _ in range(num_layers) + ] + + slot_indices = torch.arange(num_tokens, device="cuda", dtype=torch.int32) + + k_data_ptrs, v_data_ptrs, src_slot_stride, src_head_stride = create_pointer_tensors(k_buffers, v_buffers) + + output_size = num_layers * 2 * num_tokens * num_heads * head_dim + + # Pinned buffer in bfloat16 (2 bytes) - same size, so no assertion from size check + pinned_output = torch.empty(output_size, dtype=torch.bfloat16, device="cpu", pin_memory=True) + + # Pass kv_elem_bytes=1 to simulate FP8 KV cache with bfloat16 pinned buffer + # This should raise an assertion error + with pytest.raises(AssertionError, match="does not match"): + gather_kv_to_pinned_all_layers( + k_data_ptrs=k_data_ptrs, + v_data_ptrs=v_data_ptrs, + slot_indices=slot_indices, + pinned_output=pinned_output, + head_start=0, + num_heads_to_gather=num_heads, + num_layers=num_layers, + head_dim=head_dim, + src_slot_stride=src_slot_stride, + 
src_head_stride=src_head_stride, + kv_elem_bytes=1, # Simulate FP8 (1 byte) while pinned buffer is bfloat16 (2 bytes) + ) + + +def test_scatter_kv_dtype_mismatch_raises(): + """ + Test that mismatched KV cache and pinned buffer dtypes raise an assertion error. + + This catches the bug where FP8 KV cache (1 byte) was used with bfloat16 pinned buffer (2 bytes), + causing incorrect pointer arithmetic and memory corruption. + """ + num_layers = 1 + num_tokens = 64 + num_heads = 4 + head_dim = 64 + total_slots = 256 + + input_size = num_layers * 2 * num_tokens * num_heads * head_dim + # Pinned buffer in bfloat16 (2 bytes) + pinned_input = torch.randn(input_size, dtype=torch.bfloat16, device="cpu", pin_memory=True) + + slot_indices = torch.arange(num_tokens, device="cuda", dtype=torch.int32) + + k_buffers = [ + torch.zeros(total_slots, num_heads, head_dim, dtype=torch.float16, device="cuda") + for _ in range(num_layers) + ] + v_buffers = [ + torch.zeros(total_slots, num_heads, head_dim, dtype=torch.float16, device="cuda") + for _ in range(num_layers) + ] + + k_data_ptrs, v_data_ptrs, dst_slot_stride, dst_head_stride = create_pointer_tensors(k_buffers, v_buffers) + + # Pass kv_elem_bytes=1 to simulate FP8 KV cache with bfloat16 pinned buffer + # This should raise an assertion error + with pytest.raises(AssertionError, match="does not match"): + scatter_kv_with_staging_all_layers( + pinned_input=pinned_input, + k_data_ptrs=k_data_ptrs, + v_data_ptrs=v_data_ptrs, + slot_indices=slot_indices, + head_start=0, + num_heads_to_scatter=num_heads, + num_layers=num_layers, + head_dim=head_dim, + dst_slot_stride=dst_slot_stride, + dst_head_stride=dst_head_stride, + kv_elem_bytes=1, # Simulate FP8 (1 byte) while pinned buffer is bfloat16 (2 bytes) + ) + + +@pytest.mark.skipif(not has_fp8_support(), reason="FP8 requires SM89+ GPU") +def test_gather_kv_fp8_with_wrong_pinned_dtype_raises(): + """ + Test the actual failure case: FP8 KV cache with bfloat16 pinned buffer. 
+ + This is the exact bug that was causing illegal memory access. + """ + num_layers = 1 + num_tokens = 64 + num_heads = 4 + head_dim = 64 + total_slots = 256 + dtype_kv = torch.float8_e4m3fn + dtype_pinned = torch.bfloat16 + + # KV cache in FP8 (1 byte) + k_buffers = [ + torch.randn(total_slots, num_heads, head_dim, dtype=torch.float16, device="cuda").to(dtype_kv) + for _ in range(num_layers) + ] + v_buffers = [ + torch.randn(total_slots, num_heads, head_dim, dtype=torch.float16, device="cuda").to(dtype_kv) + for _ in range(num_layers) + ] + + slot_indices = torch.arange(num_tokens, device="cuda", dtype=torch.int32) + + k_data_ptrs, v_data_ptrs, src_slot_stride, src_head_stride = create_pointer_tensors(k_buffers, v_buffers) + + # Pinned buffer in bfloat16 (2 bytes) - WRONG dtype! + output_size = num_layers * 2 * num_tokens * num_heads * head_dim + pinned_output = torch.empty(output_size, dtype=dtype_pinned, device="cpu", pin_memory=True) + + kv_elem_bytes = k_buffers[0].element_size() # 1 byte for FP8 + + # This should raise an assertion because pinned buffer is bfloat16 (2 bytes) + # but kv_elem_bytes is 1 (FP8) + with pytest.raises(AssertionError, match="does not match"): + gather_kv_to_pinned_all_layers( + k_data_ptrs=k_data_ptrs, + v_data_ptrs=v_data_ptrs, + slot_indices=slot_indices, + pinned_output=pinned_output, + head_start=0, + num_heads_to_gather=num_heads, + num_layers=num_layers, + head_dim=head_dim, + src_slot_stride=src_slot_stride, + src_head_stride=src_head_stride, + kv_elem_bytes=kv_elem_bytes, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])