Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,9 @@ def to_string(self):


class ChunkedTokenDatabase:
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool, partitions: list[int] | None):
def __init__(self, metadata: KeyMetadata, block_size: int, partitions: list[int] | None):
self.metadata = metadata
self.block_size = block_size
self.use_mla = use_mla
self.kv_caches_base_addr: list[int] = []
self.block_len: list[int] = []
self.partitions = partitions
Expand All @@ -117,29 +116,24 @@ def prepare_value(self, start: int, end: int, block_ids: list[int]):
addr_list = []
size_list = []
block_id = block_ids[start // self.block_size]
length = len(self.block_len)
for index, base_addr in enumerate(self.kv_caches_base_addr):
block_len = self.block_len[index % 2] if self.use_mla else self.block_len[0]

addr = base_addr + block_id * block_len
length = int(block_len / self.block_size * (end - start))
addr = base_addr + block_id * self.block_len[index % length]
size = int(self.block_len[index % length] / self.block_size * (end - start))
addr_list.append(addr)
size_list.append(length)
size_list.append(size)
return addr_list, size_list, block_id

def prepare_value_layer(self, start: int, end: int, block_ids: list[int], layer_id: int):
    """Return the addresses and byte sizes of one layer's tensors for a token range.

    The block holding the range is looked up as
    ``block_ids[start // self.block_size]``. For each entry of
    ``self.block_len`` one address/size pair is produced.

    Args:
        start: First token index of the range.
        end: End token index of the range; ``end - start`` tokens are covered.
        block_ids: Block ids, indexed by ``start // self.block_size``.
        layer_id: Index of the layer whose tensors are addressed.

    Returns:
        ``(addr_list, size_list)``, each with ``len(self.block_len)`` entries.
    """
    block_id = block_ids[start // self.block_size]
    tensors_per_layer = len(self.block_len)
    # Base addresses are stored flat, layer after layer, so the i-th tensor of
    # this layer is at offset ``layer_id * tensors_per_layer + i``. Without the
    # ``+ i`` every tensor of the layer would reuse the first tensor's base.
    # NOTE(review): this assumes each layer contributes exactly
    # len(self.block_len) consecutive entries to kv_caches_base_addr -- confirm
    # for setups where block_len has one entry but a layer registers several
    # tensors.
    first_index = layer_id * tensors_per_layer
    addr_list = []
    size_list = []
    for i, block_len in enumerate(self.block_len):
        addr_list.append(self.kv_caches_base_addr[first_index + i] + block_id * block_len)
        # Scale the per-block byte length down to the tokens actually covered.
        size_list.append(int(block_len / self.block_size * (end - start)))
    return addr_list, size_list

def process_tokens(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
self.use_mla = False
if hasattr(model_config, "use_mla") and isinstance(model_config.use_mla, bool) and model_config.use_mla:
self.use_mla = True
self.use_sparse = hasattr(model_config.hf_text_config, "index_topk")
self.use_layerwise = use_layerwize
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
Expand Down Expand Up @@ -123,7 +124,7 @@ def __init__(
for i in range(2, remaining_layers + 2):
partitions[-i] += 1

self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size, self.use_mla, partitions)
self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size, partitions)

real_backend = backend_map.get(self.backend.lower())

Expand All @@ -145,55 +146,42 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
first_kv_cache = first_kv_cache_tuple[0]

# TODO(tms): Find a more robust way to detect and handle MLA
if self.use_mla:
# MLA case.[num_block, block_size, 1, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks,
block_shape_norm,
block_shape_pe,
)
self.num_blocks = first_kv_cache.shape[0]
logger.info("num_blocks: %s", self.num_blocks)
block_rank = 3
self.block_len = []
if self.use_mla or self.use_sparse:
for i in range(len(first_kv_cache_tuple)):
block_shape = first_kv_cache_tuple[i].shape[-block_rank:]
logger.info("block_shape: %s", block_shape)
self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape))
else:
# [num_block, block_size, num_head, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
kv_elem_size = first_kv_cache.element_size()
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = [kv_elem_size * math.prod(block_shape)]
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape)

logger.info("Registering KV_Caches. use_mla: %s, shape %s", self.use_mla, first_kv_cache.shape)
logger.info("block_shape: %s", block_shape)
self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)]

logger.info(
"Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
self.use_mla,
self.use_sparse,
first_kv_cache.shape,
)

self.kv_caches = kv_caches
self.kv_caches_base_addr = []
ptrs = []
lengths = []
length = len(self.block_len)
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
if self.use_mla:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
self.kv_caches_base_addr.append(base_addr)
region_len = self.num_blocks * self.block_len[i % 2]
ptrs.append(base_addr)
lengths.append(region_len)
else:
cache_list = [cache_or_caches] if self.use_mla else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
self.kv_caches_base_addr.append(base_addr)
region_len = self.num_blocks * self.block_len[0]
ptrs.append(base_addr)
lengths.append(region_len)
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % length]
self.kv_caches_base_addr.append(base_addr)
ptrs.append(base_addr)
lengths.append(region_len)

self.m_store.register_buffer(ptrs, lengths)
self.token_database.set_kv_caches_base_addr(self.kv_caches_base_addr)
self.token_database.set_block_len(self.block_len)
Expand Down
Loading