Skip to content
48 changes: 46 additions & 2 deletions python/sglang/srt/managers/cache_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def _generate_storage_config(
pp_size=self.pp_size,
is_mla_model=is_mla_backend,
enable_storage_metrics=self.enable_storage_metrics,
is_page_first_layout=self.mem_pool_host.layout == "page_first",
layout=self.mem_pool_host.layout,
model_name=model_name,
tp_lcm_size=tp_lcm_size,
should_split_heads=should_split_heads,
Expand Down Expand Up @@ -827,6 +827,50 @@ def _page_get_zero_copy(
inc += self.page_size
operation.increment(inc)

@staticmethod
def _count_consecutive_true(results: List[bool]) -> int:
for i, ok in enumerate(results):
if not ok:
return i
return len(results)

def _kv_get_pages(self, hash_values, host_indices, extra_info=None) -> int:
if self.storage_backend_type == "mooncake":
results = self.storage_backend.batch_get_v1(
hash_values, host_indices, extra_info
)
return self._count_consecutive_true(results)

dummy_page_dst = [
self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
]
page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
if page_data is None:
return 0
success_pages = 0
for i in range(len(hash_values)):
if page_data[i] is None:
break
self.mem_pool_host.set_from_flat_data_page(
host_indices[i * self.page_size], page_data[i]
)
success_pages += 1
return success_pages

def _kv_set_pages(self, hash_values, host_indices, extra_info=None) -> bool:
if self.storage_backend_type == "mooncake":
return all(
self.storage_backend.batch_set_v1(hash_values, host_indices, extra_info)
)
data = [
self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
for i in range(len(hash_values))
]
return self.storage_backend.batch_set(hash_values, data)

def _storage_hit_page_num(self, batch_hashes, extra_info=None) -> int:
return self.storage_backend.batch_exists(batch_hashes, extra_info)

# todo: deprecate
def _generic_page_get(self, operation, hash_values, host_indices, extra_info=None):
dummy_page_dst = [
Expand Down Expand Up @@ -922,7 +966,7 @@ def _storage_hit_query(self, operation) -> tuple[list[str], int]:
)
batch_hashes.append(last_hash)
extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
hit_page_num = self.storage_backend.batch_exists(batch_hashes, extra_info)
hit_page_num = self._storage_hit_page_num(batch_hashes, extra_info)
hash_value.extend(batch_hashes[:hit_page_num])
storage_query_count += hit_page_num * self.page_size
if hit_page_num < len(batch_hashes):
Expand Down
Loading
Loading