diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 251d16aee24..24c1eead10b 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -952,7 +952,7 @@ def get_flat_data(self, indices):
         return self.kv_buffer[:, :, indices]
 
     def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[:, layer_id - self.start_layer, indices]
+        return self.kv_buffer[:, layer_id - self.device_pool.start_layer, indices]
 
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data
@@ -977,19 +977,19 @@ def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_i
         for i in range(len(device_indices_cpu)):
             h_index = host_indices[i * self.page_size]
             d_index = device_indices_cpu[i]
-            device_pool.k_buffer[layer_id - self.start_layer][
+            device_pool.k_buffer[layer_id - self.device_pool.start_layer][
                 d_index : d_index + self.page_size
             ].copy_(
                 self.kv_buffer[
-                    0, layer_id - self.start_layer, h_index : h_index + self.page_size
+                    0, layer_id - self.device_pool.start_layer, h_index : h_index + self.page_size
                 ],
                 non_blocking=True,
             )
-            device_pool.v_buffer[layer_id - self.start_layer][
+            device_pool.v_buffer[layer_id - self.device_pool.start_layer][
                 d_index : d_index + self.page_size
             ].copy_(
                 self.kv_buffer[
-                    1, layer_id - self.start_layer, h_index : h_index + self.page_size
+                    1, layer_id - self.device_pool.start_layer, h_index : h_index + self.page_size
                 ],
                 non_blocking=True,
             )
@@ -1045,7 +1045,7 @@ def get_flat_data(self, indices):
         return self.kv_buffer[:, indices]
 
     def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[layer_id - self.start_layer, indices]
+        return self.kv_buffer[layer_id - self.device_pool.start_layer, indices]
 
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, indices] = flat_data
@@ -1066,11 +1066,11 @@ def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_i
         for i in range(len(device_indices_cpu)):
             h_index = host_indices[i * self.page_size]
             d_index = device_indices_cpu[i]
-            device_pool.kv_buffer[layer_id - self.start_layer][
+            device_pool.kv_buffer[layer_id - self.device_pool.start_layer][
                 d_index : d_index + self.page_size
             ].copy_(
                 self.kv_buffer[
-                    layer_id - self.start_layer, h_index : h_index + self.page_size
+                    layer_id - self.device_pool.start_layer, h_index : h_index + self.page_size
                 ],
                 non_blocking=True,
             )