diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py b/python/sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py
index 3083abe22cf8..ae4788cb5dd4 100644
--- a/python/sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py
+++ b/python/sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py
@@ -29,6 +29,43 @@ def generate_batch_query_keys(kv_num: int, config: HiCacheStorageConfig):
     return set_keys, get_keys, exist_keys
 
 
+def create_mock_host_kv_cache(buffer_size, dtype=torch.float32):
+    """Create a minimal HostKVCache stand-in backed by a random 1-D tensor.
+
+    Args:
+        buffer_size: Number of elements (not bytes) in the backing tensor.
+        dtype: Element dtype of the backing tensor.
+
+    Returns:
+        A ``(mock_cache, buffer)`` tuple where ``mock_cache.kv_buffer`` is
+        ``buffer``.
+    """
+    buffer = torch.randn(buffer_size, dtype=dtype)
+
+    class MockHostKVCache:
+        """Test double exposing ``kv_buffer``, ``layout`` and ``page_size``."""
+
+        def __init__(self, buffer):
+            self.kv_buffer = buffer
+            self.layout = "page_first"
+            self.page_size = 1  # Simple page size for testing
+
+        def get_page_buffer_meta(self, indices):
+            """Return ``(ptr_list, element_size_list)`` for ``indices``.
+
+            Each pointer is the tensor's base address plus the page's byte
+            offset, so it is a real address inside ``kv_buffer`` instead of
+            a bare offset from zero.
+            """
+            base_ptr = self.kv_buffer.data_ptr()
+            page_bytes = self.page_size * self.kv_buffer.element_size()
+            ptr_list = [base_ptr + idx * page_bytes for idx in indices]
+            element_size_list = [page_bytes] * len(ptr_list)
+            return ptr_list, element_size_list
+
+    return MockHostKVCache(buffer), buffer
+
+
 def test_single_operation():
     """Test the set API with a single key-value pair."""
     print("=" * 100)
@@ -37,8 +74,11 @@ def test_single_operation():
     buffer_size = 1024 * 1024 * 16  # 16MB
     value_elements = 1024
     store = MooncakeStore()
-    buffer = torch.randn(buffer_size, dtype=torch.float32)
-    store.register_buffer(buffer)
+    mock_host_kv_cache, buffer = create_mock_host_kv_cache(buffer_size)
+
+    # Register the memory pool host - this is the proper workflow
+    store.register_mem_pool_host(mock_host_kv_cache)
+
     value_size = value_elements * buffer.element_size()
 
     key = str(uuid.uuid4())
@@ -75,8 +115,10 @@ def test_batch_operation(config: HiCacheStorageConfig):
     value_elements = 256
     kv_num = 13
     store = MooncakeStore(config)
-    buffer = torch.randn(buffer_size, dtype=torch.float32)
-    store.register_buffer(buffer)
+    mock_host_kv_cache, buffer = create_mock_host_kv_cache(buffer_size)
+
+    store.register_mem_pool_host(mock_host_kv_cache)
+
     value_size = value_elements * buffer.element_size()
 
     set_keys, get_keys, exist_keys = generate_batch_query_keys(kv_num, config)