[Refactor] refactor p2p connector #6551
```diff
@@ -233,15 +233,11 @@ def get_transfer_meta(self, send_task: SendTask, req_id: str, req_meta: ReqMeta)
             remote_block_ids, local_block_ids
         )

+        block_length = len(self.block_len)
         for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
             zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)
         ):
-            if self.use_mla:
-                block_len = self.block_len[k % 2]
-            elif self.use_sparse:
-                block_len = self.block_len[k % 3]
-            else:
-                block_len = self.block_len[0]
+            block_len = self.block_len[k % block_length]
             for group_remote_block_id, group_local_block_id in zip(
                 grouped_remote_block_ids, grouped_local_block_ids
             ):
```
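The new indexing is easiest to sanity-check in isolation. A minimal sketch, assuming (as the surrounding diff suggests) that `self.block_len` holds one entry per cache tensor: one for standard attention, two for MLA, three for the sparse layout. The helper name `pick_block_len` is hypothetical, used only to show that `k % len(block_len)` picks the same lengths the removed `if`/`elif`/`else` did:

```python
# Hypothetical helper for illustration only. block_len has 1 entry (standard),
# 2 (MLA: norm + pe) or 3 (sparse: norm + pe + k), so the modulo reproduces
# the removed per-mode branches without testing use_mla / use_sparse.
def pick_block_len(block_len: list[int], k: int) -> int:
    return block_len[k % len(block_len)]

assert pick_block_len([4096], 5) == 4096           # standard: always index 0
assert pick_block_len([4096, 512], 3) == 512       # MLA: k % 2
assert pick_block_len([4096, 512, 256], 4) == 512  # sparse: k % 3
```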
```diff
@@ -925,48 +921,21 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1) and len(first_kv_cache_tuple) == 2
         )
         self.use_sparse = len(first_kv_cache_tuple) == 3
-        if self.use_mla:
-            # MLA case.[num_block, block_size, 1, hidden_dim]
-            self.num_blocks = first_kv_cache.shape[0]
-            block_rank = 3  # [block_size, latent_dim]
-            block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
-            block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
-            self.block_len = [
-                first_kv_cache[0].element_size() * math.prod(block_shape_norm),
-                first_kv_cache[1].element_size() * math.prod(block_shape_pe),
-            ]
-            logger.info(
-                "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
-                self.num_blocks,
-                block_shape_norm,
-                block_shape_pe,
-            )
-        elif self.use_sparse:
-            self.num_blocks = first_kv_cache.shape[0]
-            block_rank = 3  # [block_size, latent_dim]
-            block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
-            block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
-            block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
-            self.block_len = [
-                first_kv_cache[0].element_size() * math.prod(block_shape_norm),
-                first_kv_cache[1].element_size() * math.prod(block_shape_pe),
-                first_kv_cache[2].element_size() * math.prod(block_shape_k),
-            ]
-            logger.info(
-                "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
-                self.num_blocks,
-                block_shape_norm,
-                block_shape_pe,
-                block_shape_k,
-            )
+        self.num_blocks = first_kv_cache.shape[0]
+        logger.info("num_blocks: %s", self.num_blocks)
+        block_rank = 3
+        self.block_len = []
+        if self.use_mla or self.use_sparse:
+            for i in range(len(first_kv_cache_tuple)):
+                block_shape = first_kv_cache_tuple[i].shape[-block_rank:]
+                logger.info("block_shape: %s", block_shape)
+                self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape))
```
Contributor

Similar to the other file, there's a potential bug here.
```diff
         else:
             # [num_block, block_size, num_head, hidden_dim]
-            self.num_blocks = first_kv_cache.shape[0]
-            kv_elem_size = first_kv_cache.element_size()
-            block_rank = 3  # [block_size, kv_heads, head_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
-            self.block_len = [kv_elem_size * math.prod(block_shape)]
-            logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape)
+            logger.info("block_shape: %s", block_shape)
+            self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)]

         logger.info(
             "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
```
```diff
@@ -979,30 +948,15 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         kv_caches_base_addr = []
         ptrs = []
         lengths = []
+        length = len(self.block_len)
         for cache_or_caches in kv_caches.values():
             # Normalize to always be a list of caches
-            if self.use_mla:
-                for i, cache in enumerate(cache_or_caches, 0):
-                    base_addr = cache.data_ptr()
-                    region_len = self.num_blocks * self.block_len[i % 2]
-                    kv_caches_base_addr.append(base_addr)
-                    ptrs.append(base_addr)
-                    lengths.append(region_len)
-            elif self.use_sparse:
-                for i, cache in enumerate(cache_or_caches, 0):
-                    base_addr = cache.data_ptr()
-                    region_len = self.num_blocks * self.block_len[i % 3]
-                    kv_caches_base_addr.append(base_addr)
-                    ptrs.append(base_addr)
-                    lengths.append(region_len)
-            else:
-                cache_list = [cache_or_caches] if self.use_mla or self.use_sparse else cache_or_caches
-                for cache in cache_list:
-                    base_addr = cache.data_ptr()
-                    region_len = self.num_blocks * self.block_len[0]
-                    kv_caches_base_addr.append(base_addr)
-                    ptrs.append(base_addr)
-                    lengths.append(region_len)
+            for i, cache in enumerate(cache_or_caches, 0):
+                base_addr = cache.data_ptr()
+                region_len = self.num_blocks * self.block_len[i % length]
+                kv_caches_base_addr.append(base_addr)
+                ptrs.append(base_addr)
+                lengths.append(region_len)
         global_te.register_buffer(ptrs, lengths)
         self.kv_caches_base_addr = kv_caches_base_addr
```
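To see how the unified registration loop behaves, here is a small self-contained sketch under stated assumptions: the tensor shapes and the `"layer.0"` key are made up, and `print` stands in for the transfer engine's `register_buffer` call from the diff; only `torch` and `math` are assumed to be available.

```python
import math
import torch

num_blocks = 8
# Pretend MLA-style layout: each layer holds a "norm" tensor and a smaller "pe" tensor.
kv_caches = {
    "layer.0": (torch.zeros(num_blocks, 16, 1, 512), torch.zeros(num_blocks, 16, 1, 64)),
}
first = next(iter(kv_caches.values()))
block_len = [t.element_size() * math.prod(t.shape[-3:]) for t in first]

ptrs, lengths = [], []
length = len(block_len)
for cache_or_caches in kv_caches.values():
    for i, cache in enumerate(cache_or_caches):
        ptrs.append(cache.data_ptr())
        # One expression covers standard (length == 1), MLA (2) and sparse (3) layouts.
        lengths.append(num_blocks * block_len[i % length])

print(lengths)  # per-tensor region sizes that would be handed to register_buffer
```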
There seems to be a potential bug here.

`first_kv_cache` is assigned `first_kv_cache_tuple[0]`, so it's the first tensor in the tuple. When calculating `block_len`, you are using `first_kv_cache[i].element_size()`, which is equivalent to `first_kv_cache.element_size()` for any valid `i`. This means you are using the element size of the first tensor (`first_kv_cache_tuple[0]`) for all tensors in `first_kv_cache_tuple`. If the tensors in `first_kv_cache_tuple` have different dtypes, this will lead to an incorrect `block_len` calculation. It should probably be `first_kv_cache_tuple[i].element_size()` to get the element size of the correct tensor in the tuple.
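For illustration, a hedged sketch of the fix the comment is pointing at, written as a standalone function: the name `compute_block_len` and the example dtypes are invented, and only the per-tensor `element_size()` change is taken from the comment.

```python
import math
import torch

def compute_block_len(first_kv_cache_tuple: tuple[torch.Tensor, ...], block_rank: int = 3) -> list[int]:
    # Use each tensor's own element_size() so that tuples mixing dtypes
    # (e.g. an fp16 norm cache next to an fp32 pe cache) get correct lengths.
    return [
        t.element_size() * math.prod(t.shape[-block_rank:])
        for t in first_kv_cache_tuple
    ]

# Example with deliberately different dtypes per tensor.
norm = torch.zeros(8, 16, 1, 512, dtype=torch.float16)
pe = torch.zeros(8, 16, 1, 64, dtype=torch.float32)
print(compute_block_len((norm, pe)))  # [16*1*512*2, 16*1*64*4] == [16384, 4096]
```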