diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index 0b8199d9f7e9..276070b51897 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -1109,6 +1109,7 @@ def __init__( self.indexer_page_stride_size = ( self.indexer_size_per_token * self.page_size * self.indexer_dtype.itemsize ) + self.indexer_layout_dim = self.indexer_page_stride_size * self.layer_num self.indexer_page_num = (self.size + self.page_size + 1) // self.page_size self._init_indexer_buffers() logger.info( @@ -1124,29 +1125,45 @@ def get_size_per_token(self): def _init_indexer_buffers(self): alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device] - self.index_k_with_scale_buffer = [ - alloc_func( - (self.indexer_page_num, self.indexer_page_stride_size), - dtype=self.indexer_dtype, - device=self.device, - pin_memory=self.pin_memory, - allocator=self.allocator, - ) - for _ in range(self.layer_num) - ] - self.index_k_data_refs = [ - self.index_k_with_scale_buffer[i] for i in range(self.layer_num) - ] - self.index_k_data_ptrs = torch.tensor( - [x.data_ptr() for x in self.index_k_data_refs], - dtype=torch.uint64, - device=self.device_pool.device, - ) self.index_k_device_ptrs = torch.tensor( [x.data_ptr() for x in self.device_pool.index_k_with_scale_buffer], dtype=torch.uint64, device=self.device_pool.device, ) + if self.layout == "layer_first": + self.index_k_with_scale_buffer = [ + alloc_func( + (self.indexer_page_num, self.indexer_page_stride_size), + dtype=self.indexer_dtype, + device=self.device, + pin_memory=self.pin_memory, + allocator=self.allocator, + ) + for _ in range(self.layer_num) + ] + self.index_k_data_refs = [ + self.index_k_with_scale_buffer[i] for i in range(self.layer_num) + ] + self.index_k_data_ptrs = torch.tensor( + [x.data_ptr() for x in self.index_k_data_refs], + dtype=torch.uint64, + device=self.device_pool.device, + ) + elif self.layout in ["page_first", "page_first_direct"]: + self.index_k_with_scale_buffer = alloc_func( + ( + self.indexer_page_num, + self.layer_num, + 1, + self.indexer_page_stride_size, + ), + dtype=self.indexer_dtype, + device=self.device, + pin_memory=self.pin_memory, + allocator=self.allocator, + ) + else: + raise ValueError(f"Unsupported layout: {self.layout}") def _get_indexer_page_indices(self, host_indices, device_indices): if host_indices.numel() == 0: @@ -1171,21 +1188,46 @@ def _load_indexer_to_device_per_layer( ) use_kernel = io_backend == "kernel" and self.indexer_page_stride_size % 8 == 0 if use_kernel: - transfer_kv_per_layer_mla( - src=self.index_k_with_scale_buffer[layer_id], - dst=device_pool.index_k_with_scale_buffer[layer_id], - src_indices=host_page_indices, - dst_indices=device_page_indices, - item_size=self.indexer_page_stride_size, - ) + if self.layout == "layer_first": + transfer_kv_per_layer_mla( + src=self.index_k_with_scale_buffer[layer_id], + dst=device_pool.index_k_with_scale_buffer[layer_id], + src_indices=host_page_indices, + dst_indices=device_page_indices, + item_size=self.indexer_page_stride_size, + ) + elif self.layout == "page_first": + transfer_kv_per_layer_mla_pf_lf( + src=self.index_k_with_scale_buffer, + dst=device_pool.index_k_with_scale_buffer[layer_id], + src_indices=host_page_indices, + dst_indices=device_page_indices, + layer_id=layer_id, + item_size=self.indexer_page_stride_size, + src_layout_dim=self.indexer_layout_dim, + ) + else: + raise ValueError(f"Unsupported layout: {self.layout}") else: - transfer_kv_direct( - src_layers=[self.index_k_with_scale_buffer[layer_id]], - dst_layers=[device_pool.index_k_with_scale_buffer[layer_id]], - src_indices=host_page_indices, - dst_indices=device_page_indices, - page_size=1, - ) + if self.layout == "layer_first": + transfer_kv_direct( + src_layers=[self.index_k_with_scale_buffer[layer_id]], + dst_layers=[device_pool.index_k_with_scale_buffer[layer_id]], + src_indices=host_page_indices, + dst_indices=device_page_indices, + page_size=1, + ) + elif self.layout == "page_first_direct": + transfer_kv_per_layer_direct_pf_lf( + src_ptrs=[self.index_k_with_scale_buffer], + dst_ptrs=[device_pool.index_k_with_scale_buffer[layer_id]], + src_indices=host_page_indices, + dst_indices=device_page_indices, + layer_id=layer_id, + page_size=1, + ) + else: + raise ValueError(f"Unsupported layout: {self.layout}") def _backup_indexer_from_device_all_layer( self, device_pool, host_indices, device_indices, io_backend @@ -1195,22 +1237,46 @@ def _backup_indexer_from_device_all_layer( ) use_kernel = io_backend == "kernel" and self.indexer_page_stride_size % 8 == 0 if use_kernel: - transfer_kv_all_layer_mla( - src_layers=self.index_k_device_ptrs, - dst_layers=self.index_k_data_ptrs, - src_indices=device_page_indices, - dst_indices=host_page_indices, - item_size=self.indexer_page_stride_size, - num_layers=self.layer_num, - ) + if self.layout == "layer_first": + transfer_kv_all_layer_mla( + src_layers=self.index_k_device_ptrs, + dst_layers=self.index_k_data_ptrs, + src_indices=device_page_indices, + dst_indices=host_page_indices, + item_size=self.indexer_page_stride_size, + num_layers=self.layer_num, + ) + elif self.layout == "page_first": + transfer_kv_all_layer_mla_lf_pf( + src_layers=self.index_k_device_ptrs, + dst=self.index_k_with_scale_buffer, + src_indices=device_page_indices, + dst_indices=host_page_indices, + item_size=self.indexer_page_stride_size, + dst_layout_dim=self.indexer_layout_dim, + num_layers=self.layer_num, + ) + else: + raise ValueError(f"Unsupported layout: {self.layout}") else: - transfer_kv_direct( - src_layers=device_pool.index_k_with_scale_buffer, - dst_layers=self.index_k_with_scale_buffer, - src_indices=device_page_indices, - dst_indices=host_page_indices, - page_size=1, - ) + if self.layout == "layer_first": + transfer_kv_direct( + src_layers=device_pool.index_k_with_scale_buffer, + dst_layers=self.index_k_with_scale_buffer, + src_indices=device_page_indices, + dst_indices=host_page_indices, + page_size=1, + ) + elif self.layout == "page_first_direct": + transfer_kv_all_layer_direct_lf_pf( + src_ptrs=device_pool.index_k_with_scale_buffer, + dst_ptrs=[self.index_k_with_scale_buffer], + src_indices=device_page_indices, + dst_indices=host_page_indices, + page_size=1, + ) + else: + raise ValueError(f"Unsupported layout: {self.layout}") def load_to_device_per_layer( self,