Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 113 additions & 47 deletions python/sglang/srt/mem_cache/memory_pool_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,7 @@ def __init__(
self.indexer_page_stride_size = (
self.indexer_size_per_token * self.page_size * self.indexer_dtype.itemsize
)
self.indexer_layout_dim = self.indexer_page_stride_size * self.layer_num
self.indexer_page_num = (self.size + self.page_size + 1) // self.page_size
self._init_indexer_buffers()
logger.info(
Expand All @@ -1124,29 +1125,45 @@ def get_size_per_token(self):

def _init_indexer_buffers(self):
alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device]
self.index_k_with_scale_buffer = [
alloc_func(
(self.indexer_page_num, self.indexer_page_stride_size),
dtype=self.indexer_dtype,
device=self.device,
pin_memory=self.pin_memory,
allocator=self.allocator,
)
for _ in range(self.layer_num)
]
self.index_k_data_refs = [
self.index_k_with_scale_buffer[i] for i in range(self.layer_num)
]
self.index_k_data_ptrs = torch.tensor(
[x.data_ptr() for x in self.index_k_data_refs],
dtype=torch.uint64,
device=self.device_pool.device,
)
self.index_k_device_ptrs = torch.tensor(
[x.data_ptr() for x in self.device_pool.index_k_with_scale_buffer],
dtype=torch.uint64,
device=self.device_pool.device,
)
if self.layout == "layer_first":
self.index_k_with_scale_buffer = [
alloc_func(
(self.indexer_page_num, self.indexer_page_stride_size),
dtype=self.indexer_dtype,
device=self.device,
pin_memory=self.pin_memory,
allocator=self.allocator,
)
for _ in range(self.layer_num)
]
self.index_k_data_refs = [
self.index_k_with_scale_buffer[i] for i in range(self.layer_num)
]
self.index_k_data_ptrs = torch.tensor(
[x.data_ptr() for x in self.index_k_data_refs],
dtype=torch.uint64,
device=self.device_pool.device,
)
elif self.layout in ["page_first", "page_first_direct"]:
self.index_k_with_scale_buffer = alloc_func(
(
self.indexer_page_num,
self.layer_num,
1,
self.indexer_page_stride_size,
),
dtype=self.indexer_dtype,
device=self.device,
pin_memory=self.pin_memory,
allocator=self.allocator,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")

def _get_indexer_page_indices(self, host_indices, device_indices):
if host_indices.numel() == 0:
Expand All @@ -1171,21 +1188,46 @@ def _load_indexer_to_device_per_layer(
)
use_kernel = io_backend == "kernel" and self.indexer_page_stride_size % 8 == 0
if use_kernel:
transfer_kv_per_layer_mla(
src=self.index_k_with_scale_buffer[layer_id],
dst=device_pool.index_k_with_scale_buffer[layer_id],
src_indices=host_page_indices,
dst_indices=device_page_indices,
item_size=self.indexer_page_stride_size,
)
if self.layout == "layer_first":
transfer_kv_per_layer_mla(
src=self.index_k_with_scale_buffer[layer_id],
dst=device_pool.index_k_with_scale_buffer[layer_id],
src_indices=host_page_indices,
dst_indices=device_page_indices,
item_size=self.indexer_page_stride_size,
)
elif self.layout == "page_first":
transfer_kv_per_layer_mla_pf_lf(
src=self.index_k_with_scale_buffer,
dst=device_pool.index_k_with_scale_buffer[layer_id],
src_indices=host_page_indices,
dst_indices=device_page_indices,
layer_id=layer_id,
item_size=self.indexer_page_stride_size,
src_layout_dim=self.indexer_layout_dim,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
else:
transfer_kv_direct(
src_layers=[self.index_k_with_scale_buffer[layer_id]],
dst_layers=[device_pool.index_k_with_scale_buffer[layer_id]],
src_indices=host_page_indices,
dst_indices=device_page_indices,
page_size=1,
)
if self.layout == "layer_first":
transfer_kv_direct(
src_layers=[self.index_k_with_scale_buffer[layer_id]],
dst_layers=[device_pool.index_k_with_scale_buffer[layer_id]],
src_indices=host_page_indices,
dst_indices=device_page_indices,
page_size=1,
)
elif self.layout == "page_first_direct":
transfer_kv_per_layer_direct_pf_lf(
src_ptrs=[self.index_k_with_scale_buffer],
dst_ptrs=[device_pool.index_k_with_scale_buffer[layer_id]],
src_indices=host_page_indices,
dst_indices=device_page_indices,
layer_id=layer_id,
page_size=1,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")

def _backup_indexer_from_device_all_layer(
self, device_pool, host_indices, device_indices, io_backend
Expand All @@ -1195,22 +1237,46 @@ def _backup_indexer_from_device_all_layer(
)
use_kernel = io_backend == "kernel" and self.indexer_page_stride_size % 8 == 0
if use_kernel:
transfer_kv_all_layer_mla(
src_layers=self.index_k_device_ptrs,
dst_layers=self.index_k_data_ptrs,
src_indices=device_page_indices,
dst_indices=host_page_indices,
item_size=self.indexer_page_stride_size,
num_layers=self.layer_num,
)
if self.layout == "layer_first":
transfer_kv_all_layer_mla(
src_layers=self.index_k_device_ptrs,
dst_layers=self.index_k_data_ptrs,
src_indices=device_page_indices,
dst_indices=host_page_indices,
item_size=self.indexer_page_stride_size,
num_layers=self.layer_num,
)
elif self.layout == "page_first":
transfer_kv_all_layer_mla_lf_pf(
src_layers=self.index_k_device_ptrs,
dst=self.index_k_with_scale_buffer,
src_indices=device_page_indices,
dst_indices=host_page_indices,
item_size=self.indexer_page_stride_size,
dst_layout_dim=self.indexer_layout_dim,
num_layers=self.layer_num,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
else:
transfer_kv_direct(
src_layers=device_pool.index_k_with_scale_buffer,
dst_layers=self.index_k_with_scale_buffer,
src_indices=device_page_indices,
dst_indices=host_page_indices,
page_size=1,
)
if self.layout == "layer_first":
transfer_kv_direct(
src_layers=device_pool.index_k_with_scale_buffer,
dst_layers=self.index_k_with_scale_buffer,
src_indices=device_page_indices,
dst_indices=host_page_indices,
page_size=1,
)
elif self.layout == "page_first_direct":
transfer_kv_all_layer_direct_lf_pf(
src_ptrs=device_pool.index_k_with_scale_buffer,
dst_ptrs=[self.index_k_with_scale_buffer],
src_indices=device_page_indices,
dst_indices=host_page_indices,
page_size=1,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")

def load_to_device_per_layer(
self,
Expand Down
Loading