52 changes: 44 additions & 8 deletions python/sglang/srt/mem_cache/memory_pool.py
@@ -123,6 +123,9 @@ def __init__(
enable=enable_memory_saver
)

# used for chunked cpu-offloading
self.cpu_offloading_chunk_size = 8192

@abc.abstractmethod
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
raise NotImplementedError()
@@ -157,6 +160,12 @@ def transfer_per_layer(self, indices, flat_data, layer_id):
def register_layer_transfer_counter(self, layer_transfer_counter):
self.layer_transfer_counter = layer_transfer_counter

def get_cpu_copy(self, indices):
raise NotImplementedError()

def load_cpu_copy(self, kv_cache_cpu, indices):
raise NotImplementedError()


class TokenToKVPoolAllocator:
"""An allocator managing the indices to kv cache data."""
@@ -280,8 +289,6 @@ def __init__(

self._create_buffers()

# used for chunked cpu-offloading
self.chunk_size = 8192
self.layer_transfer_counter = None
self.device_module = torch.get_device_module(self.device)
self.alt_stream = self.device_module.Stream() if _is_cuda else None
@@ -378,10 +385,11 @@ def maybe_get_custom_mem_pool(self):
def get_cpu_copy(self, indices):
torch.cuda.synchronize()
kv_cache_cpu = []
chunk_size = self.cpu_offloading_chunk_size
for layer_id in range(self.layer_num):
kv_cache_cpu.append([])
for i in range(0, len(indices), self.chunk_size):
chunk_indices = indices[i : i + self.chunk_size]
for i in range(0, len(indices), chunk_size):
chunk_indices = indices[i : i + chunk_size]
k_cpu = self.k_buffer[layer_id][chunk_indices].to(
"cpu", non_blocking=True
)
@@ -394,12 +402,13 @@ def get_cpu_copy(self, indices):

def load_cpu_copy(self, kv_cache_cpu, indices):
torch.cuda.synchronize()
chunk_size = self.cpu_offloading_chunk_size
for layer_id in range(self.layer_num):
for i in range(0, len(indices), self.chunk_size):
chunk_indices = indices[i : i + self.chunk_size]
for i in range(0, len(indices), chunk_size):
chunk_indices = indices[i : i + chunk_size]
k_cpu, v_cpu = (
kv_cache_cpu[layer_id][i // self.chunk_size][0],
kv_cache_cpu[layer_id][i // self.chunk_size][1],
kv_cache_cpu[layer_id][i // chunk_size][0],
kv_cache_cpu[layer_id][i // chunk_size][1],
)
assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
@@ -724,6 +733,33 @@ def transfer_per_layer(self, indices, flat_data, layer_id):
flat_data = flat_data.to(device=self.device, non_blocking=False)
self.kv_buffer[layer_id - self.start_layer][indices] = flat_data

def get_cpu_copy(self, indices):
torch.cuda.synchronize()
kv_cache_cpu = []
chunk_size = self.cpu_offloading_chunk_size
for layer_id in range(self.layer_num):
kv_cache_cpu.append([])
for i in range(0, len(indices), chunk_size):
chunk_indices = indices[i : i + chunk_size]
kv_cpu = self.kv_buffer[layer_id][chunk_indices].to(
"cpu", non_blocking=True
)
kv_cache_cpu[-1].append(kv_cpu)
torch.cuda.synchronize()
return kv_cache_cpu

def load_cpu_copy(self, kv_cache_cpu, indices):
torch.cuda.synchronize()
chunk_size = self.cpu_offloading_chunk_size
for layer_id in range(self.layer_num):
for i in range(0, len(indices), chunk_size):
chunk_indices = indices[i : i + chunk_size]
kv_cpu = kv_cache_cpu[layer_id][i // chunk_size]
assert kv_cpu.shape[0] == len(chunk_indices)
kv_chunk = kv_cpu.to(self.kv_buffer[0].device, non_blocking=True)
self.kv_buffer[layer_id][chunk_indices] = kv_chunk
torch.cuda.synchronize()


class DoubleSparseTokenToKVPool(KVCache):
def __init__(
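
Note: the sketch below summarizes the pattern the new methods in this diff implement, a chunked device -> CPU -> device round trip over selected token indices. The free functions, buffer layout, and demo shapes are illustrative assumptions so the sketch runs standalone; they are not SGLang's actual MHATokenToKVPool API.

# Minimal standalone sketch of the chunked CPU-offload round trip.
# The list of per-layer [slots, heads, head_dim] tensors used here is a
# hypothetical stand-in for the pool's k_buffer / v_buffer.
import torch

CHUNK_SIZE = 8192  # mirrors self.cpu_offloading_chunk_size in the base class

def get_cpu_copy(k_buffer, v_buffer, indices, chunk_size=CHUNK_SIZE):
    """Copy the selected KV rows of every layer to CPU, chunk by chunk."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    kv_cache_cpu = []
    for layer_id in range(len(k_buffer)):
        kv_cache_cpu.append([])
        for i in range(0, len(indices), chunk_size):
            chunk_indices = indices[i : i + chunk_size]
            k_cpu = k_buffer[layer_id][chunk_indices].to("cpu", non_blocking=True)
            v_cpu = v_buffer[layer_id][chunk_indices].to("cpu", non_blocking=True)
            kv_cache_cpu[-1].append([k_cpu, v_cpu])
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return kv_cache_cpu

def load_cpu_copy(k_buffer, v_buffer, kv_cache_cpu, indices, chunk_size=CHUNK_SIZE):
    """Write the CPU chunks produced by get_cpu_copy back into the device buffers."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    for layer_id in range(len(k_buffer)):
        for i in range(0, len(indices), chunk_size):
            chunk_indices = indices[i : i + chunk_size]
            k_cpu, v_cpu = kv_cache_cpu[layer_id][i // chunk_size]
            assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
            device = k_buffer[layer_id].device
            k_buffer[layer_id][chunk_indices] = k_cpu.to(device, non_blocking=True)
            v_buffer[layer_id][chunk_indices] = v_cpu.to(device, non_blocking=True)
    if torch.cuda.is_available():
        torch.cuda.synchronize()

if __name__ == "__main__":
    # Tiny CPU-only round-trip check: 2 layers, 32 slots, 4 heads, head_dim 8.
    k_buf = [torch.randn(32, 4, 8) for _ in range(2)]
    v_buf = [torch.randn(32, 4, 8) for _ in range(2)]
    idx = torch.arange(16)
    original = k_buf[0][idx].clone()
    snapshot = get_cpu_copy(k_buf, v_buf, idx, chunk_size=8)
    k_buf[0][idx] = 0.0                      # pretend the slots were reused
    load_cpu_copy(k_buf, v_buf, snapshot, idx, chunk_size=8)
    print(torch.allclose(k_buf[0][idx], original))  # True: offloaded rows restored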