From a042d9bca4dfaca77b45a4c4ddc1fff313637750 Mon Sep 17 00:00:00 2001 From: nwpu-zxr Date: Tue, 23 Dec 2025 21:21:53 +0800 Subject: [PATCH 1/7] add event synchronize && schedule access metaserver Signed-off-by: nwpu-zxr --- vllm_ascend/attention/attention_v1.py | 7 +- vllm_ascend/attention/mla_v1.py | 2 + .../mooncake_layerwise_connector.py | 70 ++++++++++++++++--- 3 files changed, 70 insertions(+), 9 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 854ac0330be..e1b588836f6 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -169,7 +169,8 @@ class AscendMetadata: causal: bool = True # runner_type in model_config. model_runner_type: str = "" - + # prefill reshape_and_cache event + reshape_cache_event: torch.npu.Event = None class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): # Does this backend/builder support ACL Graphs for attention (default: no). @@ -628,6 +629,8 @@ def reshape_and_cache( ): if len(kv_cache) > 1: + if self.vllm_config.kv_transfer_config.is_kv_producer: + attn_metadata.reshape_cache_event = torch.npu.Event() if self.key_cache is None: self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] slots = attn_metadata.slot_mapping @@ -648,6 +651,8 @@ def reshape_and_cache( key_cache=self.key_cache, value_cache=self.value_cache, slot_indices=slots[:attn_metadata.num_actual_tokens]) + if self.vllm_config.kv_transfer_config.is_kv_producer: + attn_metadata.reshape_cache_event.record() return key, value def forward_impl( diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index ee51b0763c7..341f3ca9f5b 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -105,6 +105,7 @@ class AscendMLAPrefillMetadata: sin: torch.Tensor = None cos: torch.Tensor = None pcp_metadata: Optional[AscendPCPMetadata] = None + reshape_cache_event: torch.npu.Event = None @dataclass @@ -695,6 +696,7 @@ def __init__( kv_sharing_target_layer_name: Optional[str], **kwargs, ): + self.vllm_config = get_current_vllm_config() self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py index d1351049726..4416e67839a 100644 --- a/vllm_ascend/distributed/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py @@ -144,7 +144,7 @@ def __init__(self, raise RuntimeError("Mooncake memory registration failed. ") self.send_queue = queue.Queue[Tuple[str, ReqMeta, int, torch.Tensor, - torch.Tensor]]() + torch.Tensor, torch.npu.Event]]() self.ready_event = ready_event self.callback_func = callback_func @@ -155,15 +155,15 @@ def run(self): torch.npu.set_device(device) self.ready_event.set() while True: - req_id, req_meta, layer_index, key, value = self.send_queue.get() - self._handle_request(req_id, req_meta, layer_index, key, value) + req_id, req_meta, layer_index, key, value, reshape_cache_event = self.send_queue.get() + self._handle_request(req_id, req_meta, layer_index, key, value, reshape_cache_event) - def _handle_request(self, req_id, req_meta, layer_index, key, value): + def _handle_request(self, req_id, req_meta, layer_index, key, value, reshape_cache_event): try: logger.debug( f"Starting to transfer KV cache for request {req_id} {req_meta.remote_te_rpc_port=}." ) - self._transfer_kv_cache(req_id, req_meta, layer_index, key, value) + self._transfer_kv_cache(req_id, req_meta, layer_index, key, value, reshape_cache_event) logger.debug( f"Finished transferring KV cache for request {req_id} {req_meta.remote_te_rpc_port=}." ) @@ -171,7 +171,7 @@ def _handle_request(self, req_id, req_meta, layer_index, key, value): logger.error("Failed to transfer KV cache for request " f"{req_id}: {e}") - def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value): + def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value, reshape_cache_event): # send kv layer to remote if len(req_meta.local_block_ids) == 0: logger.debug( @@ -227,7 +227,7 @@ def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value): length_list.append(length) if self.current_layer != layer_index: self.current_layer = layer_index - self.model_stream.synchronize() + reshape_cache_event.synchronize() ret = self.engine.batch_transfer_sync_write( session_id, src_list, dst_list, length_list) if ret < 0: @@ -512,6 +512,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self._reqs_need_send_layerwise: dict[str, tuple[ int, list[int], Request]] = {} # req_id, (len(prompt), local_block_ids, request) + + self.executor = ThreadPoolExecutor(32) + self.metaserver_client = httpx.Client( + limits=httpx.Limits(max_connections=100000), + timeout=None) def get_num_new_matched_tokens( self, request: "Request", @@ -571,6 +576,36 @@ def update_state_after_alloc(self, request: "Request", params["do_remote_prefill"] = False + logger.info( + f"Send request: {request.request_id} to proxy metaserver: {params.get('metaserver', None)}" + ) + # All parameters here should appear in the returned dict of + # request_finished in the scheduler side except "request_id". + kv_transfer_params = dict( + token_ids=[], + request_id=request.request_id, + do_remote_prefill=False, + do_remote_decode=True, + remote_block_ids=local_block_ids, + remote_engine_id=self.engine_id, + remote_host=self.side_channel_host, + remote_port=self.side_channel_port, + ) + future = self.executor.submit( + self._access_metaserver, + url=params.get("metaserver", None), + message=kv_transfer_params, + ) + + def handle_exception(future): + if future.exception(): + logger.error( + f"Access metaserver fail: {future.exception()}" + ) + + future.add_done_callback(handle_exception) + + # Layerwise prefiller add request need send if params is not None and params.get("do_remote_decode"): local_block_ids = (blocks.get_block_ids()[0]) @@ -635,6 +670,21 @@ def build_connector_meta( ) return meta + def _access_metaserver(self, url, message): + success = False + retry = 0 + while retry < 3 and success is False: + retry += 1 + try: + self.metaserver_client.post(url, json=message) + success = True + except Exception as e: + logger.error( + f"Failed to connect to metaserver: {url}, retry {retry} time." + ) + if retry == 3: + raise e + def request_finished( self, request: "Request", @@ -907,6 +957,10 @@ def save_kv_layer(self, layer_name: str, kv_layer: Tuple[torch.Tensor, if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys( ): # enable decode prefix cache + if self.use_mla: + reshape_cache_event = attn_metadata[layer_name].prefill.reshape_cache_event + else: + reshape_cache_event = attn_metadata.reshape_cache_event for request in connector_metadata.requests.values(): assert len(request.local_block_ids) >= len( request.remote_block_ids @@ -965,7 +1019,7 @@ def sort_kv_cache(input_kv: list[list[int]]): ) assert self.kv_send_layer_thread is not None self.kv_send_layer_thread.send_queue.put( - (req_id, req_meta_update, self.current_layer, key, value)) + (req_id, req_meta_update, self.current_layer, key, value, reshape_cache_event)) self.current_layer += 1 def _get_remote_socket( From 6cc02d536375f9c89acc6e9e02920a0e7be9139a Mon Sep 17 00:00:00 2001 From: liziyu Date: Wed, 24 Dec 2025 11:40:05 +0800 Subject: [PATCH 2/7] del worker assess metaserver && bugfix Signed-off-by: liziyu --- vllm_ascend/attention/attention_v1.py | 3 +- vllm_ascend/attention/mla_v1.py | 2 + .../mooncake_layerwise_connector.py | 56 +++---------------- 3 files changed, 11 insertions(+), 50 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index e1b588836f6..0f5fd1c2660 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -315,6 +315,7 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.key_cache = None self.value_cache = None + self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata, @@ -629,7 +630,7 @@ def reshape_and_cache( ): if len(kv_cache) > 1: - if self.vllm_config.kv_transfer_config.is_kv_producer: + if self.is_kv_producer: attn_metadata.reshape_cache_event = torch.npu.Event() if self.key_cache is None: self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 341f3ca9f5b..cdebd6033a6 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -743,6 +743,8 @@ def __init__( self.speculative_config = self.vllm_config.speculative_config self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO + self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer + def _v_up_proj(self, x): # Convert from (N, B, L)/(N, B, 1, L) to (N, B, L) x = x.view(self.num_heads, -1, self.kv_lora_rank) diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py index 4416e67839a..c0148bb9228 100644 --- a/vllm_ascend/distributed/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py @@ -227,6 +227,12 @@ def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value, reshape_ length_list.append(length) if self.current_layer != layer_index: self.current_layer = layer_index + """ + Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang. + This issue will be fixed in CANN version 8.5.rc1. + You can manually build the master branch of the project at https://gitcode.com/cann/hixl + to resolve this issue before the 8.5.RC1 release. + """ reshape_cache_event.synchronize() ret = self.engine.batch_transfer_sync_write( session_id, src_list, dst_list, length_list) @@ -726,11 +732,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.total_layers = vllm_config.model_config.get_num_layers( vllm_config.parallel_config) - self.executor = ThreadPoolExecutor(32) - self.metaserver_client = httpx.Client( - limits=httpx.Limits(max_connections=100000), - timeout=None) if self.tp_rank == 0 else None - # Handshake base port self.side_channel_port = ( vllm_config.kv_transfer_config.kv_port + @@ -884,21 +885,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.kv_recv_layer_thread.start() ready_event.wait() - def _access_metaserver(self, url, message): - success = False - retry = 0 - while retry < 3 and success is False: - retry += 1 - try: - self.metaserver_client.post(url, json=message) - success = True - except Exception as e: - logger.error( - f"Failed to connect to metaserver: {url}, retry {retry} time." - ) - if retry == 3: - raise e - def get_finished(self) -> tuple[set[str], set[str]]: done_recving = ( self.kv_recv_layer_thread. @@ -915,35 +901,6 @@ def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata): self.current_layer = 0 if self.vllm_config.kv_transfer_config.is_kv_consumer: for req_id, meta in metadata.requests.items(): - if self.tp_rank % self.tp_size == 0: - logger.info( - f"Send request: {req_id} to proxy metaserver: {meta.metaserver}" - ) - # All parameters here should appear in the returned dict of - # request_finished in the scheduler side except "request_id". - kv_transfer_params = dict( - token_ids=meta.token_ids, - request_id=req_id, - do_remote_prefill=False, - do_remote_decode=True, - remote_block_ids=meta.local_block_ids, - remote_engine_id=self.engine_id, - remote_host=self.side_channel_host, - remote_port=self.side_channel_port, - ) - future = self.executor.submit( - self._access_metaserver, - url=meta.metaserver, - message=kv_transfer_params, - ) - - def handle_exception(future): - if future.exception(): - logger.error( - f"Access metaserver fail: {future.exception()}" - ) - - future.add_done_callback(handle_exception) assert self.kv_recv_layer_thread is not None with self.kv_recv_layer_thread.lock: self.kv_recv_layer_thread.task_tracker[req_id] = 0 @@ -1018,6 +975,7 @@ def sort_kv_cache(input_kv: list[list[int]]): f"Add request {req_id} to kv send layer thread. {req_meta_update=}" ) assert self.kv_send_layer_thread is not None + assert reshape_cache_event is not None self.kv_send_layer_thread.send_queue.put( (req_id, req_meta_update, self.current_layer, key, value, reshape_cache_event)) self.current_layer += 1 From 4c8d93d796d0c7c9abc1aaf57321fbd4bfd7fa2f Mon Sep 17 00:00:00 2001 From: liziyu Date: Wed, 24 Dec 2025 13:14:23 +0800 Subject: [PATCH 3/7] push kv cache for each chunk Signed-off-by: liziyu --- .../mooncake_layerwise_connector.py | 164 +++++++++++------- 1 file changed, 98 insertions(+), 66 deletions(-) diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py index c0148bb9228..2bc0029efaf 100644 --- a/vllm_ascend/distributed/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py @@ -65,6 +65,32 @@ class ReqMeta: remote_te_rpc_port: Optional[int] remote_kv_caches_base_addr: Optional[list[int]] metaserver: Optional[str] + chunk_finish: Optional[bool] + + +@dataclass +class SendReqInfo: + local_block_ids: list[int] + remote_block_ids: List[int] + remote_cache_tokens: int + local_transfered_tokens: int + local_computed_tokens: int + request: "Request" + + def extend_local_block_ids(self, new_block_ids: List[int]) -> None: + """extend local block ids for this step""" + self.local_block_ids.extend(new_block_ids) + + def update_computed_tokens(self, computed_tokens: int) -> None: + """update local computen tokens for this step""" + self.local_computed_tokens = computed_tokens + + def update_transfered_tokens(self, transferred_tokens: int) -> None: + """update transfered tokens for this step""" + self.local_transfered_tokens = transferred_tokens + + def unpack(self): + return self.local_block_ids, self.remote_block_ids, self.remote_cache_tokens, self.local_transfered_tokens, self.local_computed_tokens, self.request @dataclass @@ -172,12 +198,6 @@ def _handle_request(self, req_id, req_meta, layer_index, key, value, reshape_cac f"{req_id}: {e}") def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value, reshape_cache_event): - # send kv layer to remote - if len(req_meta.local_block_ids) == 0: - logger.debug( - f"Cancelling KV cache transfer for request {req_id}. Reason: No local blocks to transfer." - ) - return # not need to send kv cache if self.tp_rank % self.num_head_replica != 0: logger.debug( @@ -291,7 +311,7 @@ def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value, reshape_ logger.error("Mooncake transfer failed for request %s", req_id) raise RuntimeError(f"Mooncake transfer failed, ret: {ret}") - if layer_index == (self.total_layers - 1): + if layer_index == (self.total_layers - 1) and req_meta.chunk_finish: self.callback_func(req_id, req_meta) @@ -382,7 +402,8 @@ def add_new_req(self, request_id: str, local_block_ids: list[int], kv_transfer_params: dict[str, Any], - token_ids: Optional[list[int]] = None): + token_ids: Optional[list[int]] = None, + chunk_finish: bool = False): self.requests[request_id] = ReqMeta( token_ids=token_ids or [], local_block_ids=local_block_ids, @@ -395,6 +416,7 @@ def add_new_req(self, remote_kv_caches_base_addr=kv_transfer_params.get( "remote_kv_caches_base_addr", None), metaserver=kv_transfer_params.get("metaserver", None), + chunk_finish=chunk_finish ) @@ -515,9 +537,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # the scheduler. Used to make metadata passed to Worker. self._reqs_need_recv: dict[str, tuple[Request, list[int], list[int]]] = {} - self._reqs_need_send_layerwise: dict[str, tuple[ - int, list[int], - Request]] = {} # req_id, (len(prompt), local_block_ids, request) + self._reqs_need_send_layerwise: dict[str, SendReqInfo] = {} self.executor = ThreadPoolExecutor(32) self.metaserver_client = httpx.Client( @@ -618,8 +638,11 @@ def handle_exception(future): logger.debug( f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need send queue" ) - self._reqs_need_send_layerwise[request.request_id] = (len( - request.all_token_ids), local_block_ids, request) + remote_block_ids = copy.deepcopy(params["remote_block_ids"]) + remote_cache_tokens = ((len(request.all_token_ids) + self.block_size - 1) // self.block_size - len(remote_block_ids)) * self.block_size + local_transfered_tokens = remote_cache_tokens + local_computed_tokens = None + self._reqs_need_send_layerwise[request.request_id] = SendReqInfo(local_block_ids=local_block_ids, remote_block_ids=remote_block_ids, remote_cache_tokens=remote_cache_tokens, local_transfered_tokens=local_transfered_tokens, local_computed_tokens=local_computed_tokens, request=request) def build_connector_meta( self, @@ -627,53 +650,67 @@ def build_connector_meta( ) -> KVConnectorMetadata: meta = MooncakeLayerwiseConnectorMetadata() - # Loop through scheduled reqs and convert to ReqMeta. - for req_id, (req, token_ids, - block_ids) in self._reqs_need_recv.items(): - assert req.kv_transfer_params is not None - # For the case where there are no remote blocks to pull - # (block_ids is empty), we don't need to schedule - # an async read on the worker side. - meta.add_new_req(request_id=req_id, - local_block_ids=block_ids, - kv_transfer_params=req.kv_transfer_params, - token_ids=token_ids) - - # Clear the list once workers start the transfers - self._reqs_need_recv.clear() - - cached_reqs = scheduler_output.scheduled_cached_reqs - new_reqs = scheduler_output.scheduled_new_reqs - for req_id, new_blocks in zip(cached_reqs.req_ids, - cached_reqs.new_block_ids): - if req_id in self._reqs_need_send_layerwise and new_blocks is not None: - total_tokens, block_ids, req = self._reqs_need_send_layerwise[ - req_id] - block_ids.extend(new_blocks[0]) - - computed_tokens = dict( - list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) + - [(x.req_id, x.num_computed_tokens) for x in new_reqs]) - for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items( - ): - if req_id in self._reqs_need_send_layerwise: - total_tokens, block_ids, req = self._reqs_need_send_layerwise[ - req_id] - current_tokens = computed_tokens.get(req_id, - 0) + scheduled_tokens - if current_tokens >= total_tokens: - logger.debug( - f"MooncakeLayerwiseConnector build_connector_meta: add {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})" - ) - meta.add_new_req(request_id=req_id, - local_block_ids=block_ids, - kv_transfer_params=req.kv_transfer_params, - token_ids=[]) - self._reqs_need_send_layerwise.pop(req_id) - else: - logger.debug( - f"MooncakeLayerwiseConnector build_connector_meta: skip {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})" - ) + if self.vllm_config.kv_transfer_config.is_kv_consumer: + # Loop through scheduled reqs and convert to ReqMeta. + for req_id, (req, token_ids, + block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + # For the case where there are no remote blocks to pull + # (block_ids is empty), we don't need to schedule + # an async read on the worker side. + meta.add_new_req(request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + token_ids=token_ids) + + # Clear the list once workers start the transfers + self._reqs_need_recv.clear() + else: + cached_reqs = scheduler_output.scheduled_cached_reqs + new_reqs = scheduler_output.scheduled_new_reqs + scheduled_spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens + # update local block ids + for req_id, new_blocks in zip(cached_reqs.req_ids, + cached_reqs.new_block_ids): + if req_id in self._reqs_need_send_layerwise and new_blocks is not None: + self._reqs_need_send_layerwise[req_id].extend_local_block_ids(new_blocks[0]) + + computed_tokens = dict( + list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) + + [(x.req_id, x.num_computed_tokens) for x in new_reqs]) + for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items( + ): + if req_id in self._reqs_need_send_layerwise: + send_req_info = self._reqs_need_send_layerwise[req_id] + # update local computed tokens, not transfer spec decode tokens + spec_decode_tokens = len(scheduled_spec_decode_tokens[req_id]) if (req_id in scheduled_spec_decode_tokens) else 0 + send_req_info.update_computed_tokens(computed_tokens.get(req_id,0) + scheduled_tokens - spec_decode_tokens) + + def add_tranfer_task(req_id, send_req_info: SendReqInfo, chunk_finish=False): + local_block_ids, remote_block_ids, remote_cache_tokens, local_transfered_tokens, local_computed_tokens, request = send_req_info.unpack() + local_trans_block_ids = local_block_ids[(local_transfered_tokens // self.block_size): (local_computed_tokens // self.block_size)] + remote_trans_block_ids = remote_block_ids[((local_transfered_tokens - remote_cache_tokens) // self.block_size): ((local_computed_tokens - remote_cache_tokens) // self.block_size)] + request.kv_transfer_params["remote_block_ids"] = remote_trans_block_ids + assert len(local_trans_block_ids)==len(remote_trans_block_ids), f"len of local trans block ids : {len(local_trans_block_ids)} not equal to the len of remote trans block ids : {len(remote_trans_block_ids)}" + adjusted_tokens = local_computed_tokens - (self.block_size - 1) if chunk_finish else local_computed_tokens + logger.info(f"MooncakeLayerwiseConnector scheduler add transfer task: {req_id=} {local_block_ids=} {remote_block_ids=} {local_trans_block_ids=} {remote_trans_block_ids=} local_computed_tokens={adjusted_tokens} request.all_token_ids={len(request.all_token_ids)}") + meta.add_new_req(request_id=req_id, + local_block_ids=local_trans_block_ids, + kv_transfer_params=request.kv_transfer_params, + token_ids=[], + chunk_finish = chunk_finish) + # update local_transfered_tokens + local_transfered_tokens = (local_computed_tokens // self.block_size) * self.block_size + send_req_info.update_transfered_tokens(local_transfered_tokens) + + # no chunk or last chunk + if send_req_info.local_computed_tokens >= len(send_req_info.request.all_token_ids): + send_req_info.update_computed_tokens(send_req_info.local_computed_tokens + self.block_size - 1) + add_tranfer_task(req_id, send_req_info, chunk_finish=True) + self._reqs_need_send_layerwise.pop(req_id) + # chunk + elif (send_req_info.local_computed_tokens // self.block_size) - (send_req_info.local_transfered_tokens // self.block_size) > 0: + add_tranfer_task(req_id, send_req_info) return meta def _access_metaserver(self, url, message): @@ -918,12 +955,7 @@ def save_kv_layer(self, layer_name: str, kv_layer: Tuple[torch.Tensor, reshape_cache_event = attn_metadata[layer_name].prefill.reshape_cache_event else: reshape_cache_event = attn_metadata.reshape_cache_event - for request in connector_metadata.requests.values(): - assert len(request.local_block_ids) >= len( - request.remote_block_ids - ), "When prefix cache enabled, remote KVCacheBlocks num should not larger than local KVCacheBlocks num." - request.local_block_ids = request.local_block_ids[ - -len(request.remote_block_ids):] + if self.pd_head_ratio != 1: def sort_kv_cache(input_kv: list[list[int]]): From dc1e7058411ea8c20a7a255de24e297d8ab6231c Mon Sep 17 00:00:00 2001 From: nwpu-zxr Date: Mon, 29 Dec 2025 14:56:42 +0800 Subject: [PATCH 4/7] fix lint Signed-off-by: nwpu-zxr --- vllm_ascend/attention/attention_v1.py | 3 +- vllm_ascend/attention/mla_v1.py | 6 +- .../mooncake_layerwise_connector.py | 161 +++++++++++------- 3 files changed, 110 insertions(+), 60 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 0f5fd1c2660..7626c1def61 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -172,6 +172,7 @@ class AscendMetadata: # prefill reshape_and_cache event reshape_cache_event: torch.npu.Event = None + class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): # Does this backend/builder support ACL Graphs for attention (default: no). aclgraph_support: ClassVar[AttentionCGSupport] = \ @@ -652,7 +653,7 @@ def reshape_and_cache( key_cache=self.key_cache, value_cache=self.value_cache, slot_indices=slots[:attn_metadata.num_actual_tokens]) - if self.vllm_config.kv_transfer_config.is_kv_producer: + if self.is_kv_producer: attn_metadata.reshape_cache_event.record() return key, value diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index cdebd6033a6..fd27975b956 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -105,7 +105,6 @@ class AscendMLAPrefillMetadata: sin: torch.Tensor = None cos: torch.Tensor = None pcp_metadata: Optional[AscendPCPMetadata] = None - reshape_cache_event: torch.npu.Event = None @dataclass @@ -164,6 +163,7 @@ class AscendMLAMetadata: decode: Optional[AscendMLADecodeMetadata] = None prefill: Optional[AscendMLAPrefillMetadata] = None + reshape_cache_event: torch.npu.Event = None def __post_init__(self): pass @@ -1325,8 +1325,12 @@ def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, prefill_slots = attn_metadata.slot_mapping[ num_decode_tokens:num_actual_tokens] prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) + if self.is_kv_producer: + attn_metadata.reshape_cache_event = torch.npu.Event() prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill( prefill_kv_no_split, cos, sin, kv_cache, prefill_slots) + if self.is_kv_producer: + attn_metadata.reshape_cache_event.record() prefill_k_nope, prefill_value = self.kv_b_proj( prefill_k_c_normed)[0].view( -1, self.num_heads, diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py index 2bc0029efaf..12c37ef86c2 100644 --- a/vllm_ascend/distributed/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py @@ -73,7 +73,7 @@ class SendReqInfo: local_block_ids: list[int] remote_block_ids: List[int] remote_cache_tokens: int - local_transfered_tokens: int + local_transferred_tokens: int local_computed_tokens: int request: "Request" @@ -84,13 +84,13 @@ def extend_local_block_ids(self, new_block_ids: List[int]) -> None: def update_computed_tokens(self, computed_tokens: int) -> None: """update local computen tokens for this step""" self.local_computed_tokens = computed_tokens - - def update_transfered_tokens(self, transferred_tokens: int) -> None: - """update transfered tokens for this step""" - self.local_transfered_tokens = transferred_tokens + + def update_transferred_tokens(self, transferred_tokens: int) -> None: + """update transferred tokens for this step""" + self.local_transferred_tokens = transferred_tokens def unpack(self): - return self.local_block_ids, self.remote_block_ids, self.remote_cache_tokens, self.local_transfered_tokens, self.local_computed_tokens, self.request + return self.local_block_ids, self.remote_block_ids, self.remote_cache_tokens, self.local_transferred_tokens, self.local_computed_tokens, self.request @dataclass @@ -181,15 +181,19 @@ def run(self): torch.npu.set_device(device) self.ready_event.set() while True: - req_id, req_meta, layer_index, key, value, reshape_cache_event = self.send_queue.get() - self._handle_request(req_id, req_meta, layer_index, key, value, reshape_cache_event) + req_id, req_meta, layer_index, key, value, reshape_cache_event = self.send_queue.get( + ) + self._handle_request(req_id, req_meta, layer_index, key, value, + reshape_cache_event) - def _handle_request(self, req_id, req_meta, layer_index, key, value, reshape_cache_event): + def _handle_request(self, req_id, req_meta, layer_index, key, value, + reshape_cache_event): try: logger.debug( f"Starting to transfer KV cache for request {req_id} {req_meta.remote_te_rpc_port=}." ) - self._transfer_kv_cache(req_id, req_meta, layer_index, key, value, reshape_cache_event) + self._transfer_kv_cache(req_id, req_meta, layer_index, key, value, + reshape_cache_event) logger.debug( f"Finished transferring KV cache for request {req_id} {req_meta.remote_te_rpc_port=}." ) @@ -197,7 +201,8 @@ def _handle_request(self, req_id, req_meta, layer_index, key, value, reshape_cac logger.error("Failed to transfer KV cache for request " f"{req_id}: {e}") - def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value, reshape_cache_event): + def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value, + reshape_cache_event): # not need to send kv cache if self.tp_rank % self.num_head_replica != 0: logger.debug( @@ -416,8 +421,7 @@ def add_new_req(self, remote_kv_caches_base_addr=kv_transfer_params.get( "remote_kv_caches_base_addr", None), metaserver=kv_transfer_params.get("metaserver", None), - chunk_finish=chunk_finish - ) + chunk_finish=chunk_finish) class MooncakeLayerwiseConnector(KVConnectorBase_V1): @@ -538,11 +542,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self._reqs_need_recv: dict[str, tuple[Request, list[int], list[int]]] = {} self._reqs_need_send_layerwise: dict[str, SendReqInfo] = {} - + self.executor = ThreadPoolExecutor(32) self.metaserver_client = httpx.Client( - limits=httpx.Limits(max_connections=100000), - timeout=None) + limits=httpx.Limits(max_connections=100000), timeout=None) def get_num_new_matched_tokens( self, request: "Request", @@ -626,12 +629,10 @@ def update_state_after_alloc(self, request: "Request", def handle_exception(future): if future.exception(): logger.error( - f"Access metaserver fail: {future.exception()}" - ) + f"Access metaserver fail: {future.exception()}") future.add_done_callback(handle_exception) - # Layerwise prefiller add request need send if params is not None and params.get("do_remote_decode"): local_block_ids = (blocks.get_block_ids()[0]) @@ -639,10 +640,18 @@ def handle_exception(future): f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need send queue" ) remote_block_ids = copy.deepcopy(params["remote_block_ids"]) - remote_cache_tokens = ((len(request.all_token_ids) + self.block_size - 1) // self.block_size - len(remote_block_ids)) * self.block_size - local_transfered_tokens = remote_cache_tokens - local_computed_tokens = None - self._reqs_need_send_layerwise[request.request_id] = SendReqInfo(local_block_ids=local_block_ids, remote_block_ids=remote_block_ids, remote_cache_tokens=remote_cache_tokens, local_transfered_tokens=local_transfered_tokens, local_computed_tokens=local_computed_tokens, request=request) + remote_cache_tokens = ( + (len(request.all_token_ids) + self.block_size - 1) // + self.block_size - len(remote_block_ids)) * self.block_size + local_transferred_tokens = remote_cache_tokens + local_computed_tokens = 0 + self._reqs_need_send_layerwise[request.request_id] = SendReqInfo( + local_block_ids=local_block_ids, + remote_block_ids=remote_block_ids, + remote_cache_tokens=remote_cache_tokens, + local_transferred_tokens=local_transferred_tokens, + local_computed_tokens=local_computed_tokens, + request=request) def build_connector_meta( self, @@ -653,15 +662,15 @@ def build_connector_meta( if self.vllm_config.kv_transfer_config.is_kv_consumer: # Loop through scheduled reqs and convert to ReqMeta. for req_id, (req, token_ids, - block_ids) in self._reqs_need_recv.items(): + block_ids) in self._reqs_need_recv.items(): assert req.kv_transfer_params is not None # For the case where there are no remote blocks to pull # (block_ids is empty), we don't need to schedule # an async read on the worker side. meta.add_new_req(request_id=req_id, - local_block_ids=block_ids, - kv_transfer_params=req.kv_transfer_params, - token_ids=token_ids) + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + token_ids=token_ids) # Clear the list once workers start the transfers self._reqs_need_recv.clear() @@ -671,45 +680,79 @@ def build_connector_meta( scheduled_spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens # update local block ids for req_id, new_blocks in zip(cached_reqs.req_ids, - cached_reqs.new_block_ids): + cached_reqs.new_block_ids): if req_id in self._reqs_need_send_layerwise and new_blocks is not None: - self._reqs_need_send_layerwise[req_id].extend_local_block_ids(new_blocks[0]) + self._reqs_need_send_layerwise[ + req_id].extend_local_block_ids(new_blocks[0]) computed_tokens = dict( - list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) + - [(x.req_id, x.num_computed_tokens) for x in new_reqs]) + list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) + + [(x.req_id, x.num_computed_tokens) for x in new_reqs]) for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items( ): if req_id in self._reqs_need_send_layerwise: send_req_info = self._reqs_need_send_layerwise[req_id] # update local computed tokens, not transfer spec decode tokens - spec_decode_tokens = len(scheduled_spec_decode_tokens[req_id]) if (req_id in scheduled_spec_decode_tokens) else 0 - send_req_info.update_computed_tokens(computed_tokens.get(req_id,0) + scheduled_tokens - spec_decode_tokens) - - def add_tranfer_task(req_id, send_req_info: SendReqInfo, chunk_finish=False): - local_block_ids, remote_block_ids, remote_cache_tokens, local_transfered_tokens, local_computed_tokens, request = send_req_info.unpack() - local_trans_block_ids = local_block_ids[(local_transfered_tokens // self.block_size): (local_computed_tokens // self.block_size)] - remote_trans_block_ids = remote_block_ids[((local_transfered_tokens - remote_cache_tokens) // self.block_size): ((local_computed_tokens - remote_cache_tokens) // self.block_size)] - request.kv_transfer_params["remote_block_ids"] = remote_trans_block_ids - assert len(local_trans_block_ids)==len(remote_trans_block_ids), f"len of local trans block ids : {len(local_trans_block_ids)} not equal to the len of remote trans block ids : {len(remote_trans_block_ids)}" - adjusted_tokens = local_computed_tokens - (self.block_size - 1) if chunk_finish else local_computed_tokens - logger.info(f"MooncakeLayerwiseConnector scheduler add transfer task: {req_id=} {local_block_ids=} {remote_block_ids=} {local_trans_block_ids=} {remote_trans_block_ids=} local_computed_tokens={adjusted_tokens} request.all_token_ids={len(request.all_token_ids)}") - meta.add_new_req(request_id=req_id, - local_block_ids=local_trans_block_ids, - kv_transfer_params=request.kv_transfer_params, - token_ids=[], - chunk_finish = chunk_finish) - # update local_transfered_tokens - local_transfered_tokens = (local_computed_tokens // self.block_size) * self.block_size - send_req_info.update_transfered_tokens(local_transfered_tokens) - + spec_decode_tokens = len( + scheduled_spec_decode_tokens[req_id]) if ( + req_id in scheduled_spec_decode_tokens) else 0 + send_req_info.update_computed_tokens( + computed_tokens.get(req_id, 0) + scheduled_tokens - + spec_decode_tokens) + + def add_tranfer_task(req_id, + send_req_info: SendReqInfo, + chunk_finish=False): + local_block_ids, remote_block_ids, remote_cache_tokens, local_transferred_tokens, local_computed_tokens, request = send_req_info.unpack( + ) + local_trans_block_ids = local_block_ids[( + local_transferred_tokens // + self.block_size):(local_computed_tokens // + self.block_size)] + remote_trans_block_ids = remote_block_ids[( + (local_transferred_tokens - remote_cache_tokens) // + self.block_size):((local_computed_tokens - + remote_cache_tokens) // + self.block_size)] + request.kv_transfer_params[ + "remote_block_ids"] = remote_trans_block_ids + assert len(local_trans_block_ids) == len( + remote_trans_block_ids + ), f"len of local trans block ids : {len(local_trans_block_ids)} not equal to the len of remote trans block ids : {len(remote_trans_block_ids)}" + adjusted_tokens = local_computed_tokens - ( + self.block_size - + 1) if chunk_finish else local_computed_tokens + logger.info( + f"MooncakeLayerwiseConnector scheduler add transfer task: {req_id=} {local_block_ids=} {remote_block_ids=} {local_trans_block_ids=} {remote_trans_block_ids=} local_computed_tokens={adjusted_tokens} request.all_token_ids={len(request.all_token_ids)}" + ) + meta.add_new_req( + request_id=req_id, + local_block_ids=local_trans_block_ids, + kv_transfer_params=request.kv_transfer_params, + token_ids=[], + chunk_finish=chunk_finish) + # update local_transferred_tokens + local_transferred_tokens = ( + local_computed_tokens // + self.block_size) * self.block_size + send_req_info.update_transferred_tokens( + local_transferred_tokens) + # no chunk or last chunk - if send_req_info.local_computed_tokens >= len(send_req_info.request.all_token_ids): - send_req_info.update_computed_tokens(send_req_info.local_computed_tokens + self.block_size - 1) - add_tranfer_task(req_id, send_req_info, chunk_finish=True) + if send_req_info.local_computed_tokens >= len( + send_req_info.request.all_token_ids): + send_req_info.update_computed_tokens( + send_req_info.local_computed_tokens + + self.block_size - 1) + add_tranfer_task(req_id, + send_req_info, + chunk_finish=True) self._reqs_need_send_layerwise.pop(req_id) # chunk - elif (send_req_info.local_computed_tokens // self.block_size) - (send_req_info.local_transfered_tokens // self.block_size) > 0: + elif (send_req_info.local_computed_tokens // + self.block_size) - ( + send_req_info.local_transferred_tokens // + self.block_size) > 0: add_tranfer_task(req_id, send_req_info) return meta @@ -952,10 +995,11 @@ def save_kv_layer(self, layer_name: str, kv_layer: Tuple[torch.Tensor, ): # enable decode prefix cache if self.use_mla: - reshape_cache_event = attn_metadata[layer_name].prefill.reshape_cache_event + reshape_cache_event = attn_metadata[ + layer_name].reshape_cache_event else: reshape_cache_event = attn_metadata.reshape_cache_event - + if self.pd_head_ratio != 1: def sort_kv_cache(input_kv: list[list[int]]): @@ -1009,7 +1053,8 @@ def sort_kv_cache(input_kv: list[list[int]]): assert self.kv_send_layer_thread is not None assert reshape_cache_event is not None self.kv_send_layer_thread.send_queue.put( - (req_id, req_meta_update, self.current_layer, key, value, reshape_cache_event)) + (req_id, req_meta_update, self.current_layer, key, value, + reshape_cache_event)) self.current_layer += 1 def _get_remote_socket( From 9e85e6820b416ca5172ff58352137a44d162df3e Mon Sep 17 00:00:00 2001 From: wangxiaoteng Date: Mon, 29 Dec 2025 20:41:41 +0800 Subject: [PATCH 5/7] refactoring Signed-off-by: wangxiaoteng --- tests/ut/attention/test_mla_v1.py | 1 + .../ut/kv_connector/test_mooncake_layerwise_connector.py | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 88d5071d7b9..ae51a8753b2 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -1112,6 +1112,7 @@ def test_mla_preprocess(self, magic_npu_fetch, MagicMock(), MagicMock() ] self.impl.num_kv_heads = self.impl.num_heads + self.impl.is_kv_producer = False decode_res, prefill_res = self.impl._mla_preprocess( "mock_layer", diff --git a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py index e2f84d9f8d9..1006a6bf29d 100644 --- a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py +++ b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py @@ -71,7 +71,8 @@ def setUp(self): remote_port=7777, remote_te_rpc_port=6000, remote_kv_caches_base_addr=[4000, 8000, 14000, 18000], - metaserver="http://dummy") + metaserver="http://dummy", + chunk_finish=False) @patch( "vllm_ascend.distributed.mooncake_layerwise_connector.torch.Tensor.data_ptr", @@ -176,6 +177,7 @@ def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group): req_meta = self.req_meta_base req_meta.local_block_ids = [5, 6] req_meta.remote_block_ids = [10, 11] + req_meta = True req_meta.remote_kv_caches_base_addr = [ 7000, 8000, 9000, 10000, 11000, 12000 @@ -468,6 +470,7 @@ def test_build_connector_meta(self): request = MockRequest("req1") self.scheduler._reqs_need_recv["req1"] = (request, [], [4, 5, 6]) + self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True request.kv_transfer_params = { "remote_block_ids": [1, 2, 3], "remote_engine_id": "remote", @@ -505,7 +508,8 @@ def __init__(self, cached_new_block_ids=None, cached_num_computed=None, new_reqs=None, - num_sched=None): + num_sched=None, + scheduled_spec_decode_tokens=None): self.scheduled_cached_reqs = SimpleNamespace( req_ids=cached_req_ids or [], new_block_ids=cached_new_block_ids or [], From 42548858c3d0ab360ddcee01402b0e5624b592cd Mon Sep 17 00:00:00 2001 From: liziyu Date: Tue, 30 Dec 2025 08:48:29 +0800 Subject: [PATCH 6/7] add super init for MooncakeLayerwiseConnector Signed-off-by: liziyu --- vllm_ascend/distributed/mooncake_layerwise_connector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py index 12c37ef86c2..9d9d9301a6f 100644 --- a/vllm_ascend/distributed/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py @@ -430,6 +430,7 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: Optional[KVCacheConfig] = None): + super().__init__(vllm_config, role, kv_cache_config) assert vllm_config.kv_transfer_config is not None self.engine_id = vllm_config.kv_transfer_config.engine_id self._connector_metadata = MooncakeLayerwiseConnectorMetadata() From 7a688805ab9c13cf4e75b9fc09c71e0d88921060 Mon Sep 17 00:00:00 2001 From: wangxiaoteng Date: Tue, 30 Dec 2025 09:52:49 +0800 Subject: [PATCH 7/7] add reshape_cache_event Signed-off-by: wangxiaoteng --- .../test_mooncake_layerwise_connector.py | 37 +++++++++++++------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py index 1006a6bf29d..c3d8a67f0fa 100644 --- a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py +++ b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py @@ -118,7 +118,8 @@ def test_transfer_pd_gt1_uses_buffers_and_calls_engine( req_meta=req_meta, layer_index=0, key=key, - value=value) + value=value, + reshape_cache_event=MagicMock()) self.engine.batch_transfer_sync_write.assert_called_once() session_id, src_list, dst_list, length_list = self.engine.batch_transfer_sync_write.call_args[ @@ -143,8 +144,12 @@ def test_transfer_pd_gt1_uses_buffers_and_calls_engine( def test_transfer_skips_when_no_local_blocks(self): req_meta = self.req_meta_base req_meta.local_block_ids = [] - self.thread._transfer_kv_cache("req2", req_meta, 0, torch.zeros( - (1, 8)), torch.zeros((1, 8))) + self.thread._transfer_kv_cache("req2", + req_meta, + 0, + torch.zeros((1, 8)), + torch.zeros((1, 8)), + reshape_cache_event=MagicMock()) self.engine.batch_transfer_sync_write.assert_not_called() def test_transfer_skips_when_tp_not_sender(self): @@ -162,8 +167,12 @@ def test_transfer_skips_when_tp_not_sender(self): first_kv_cache=self.first_kv_cache, callback_func=MagicMock()) req_meta = self.req_meta_base - thread._transfer_kv_cache("req3", req_meta, 0, torch.zeros((1, 8)), - torch.zeros((1, 8))) + thread._transfer_kv_cache("req3", + req_meta, + 0, + torch.zeros((1, 8)), + torch.zeros((1, 8)), + reshape_cache_event=MagicMock()) self.engine.batch_transfer_sync_write.assert_not_called() @patch( @@ -190,7 +199,8 @@ def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group): req_meta, layer_index=2, key=key, - value=value) + value=value, + reshape_cache_event=MagicMock()) self.thread.callback_func.assert_called_once() @@ -517,6 +527,7 @@ def __init__(self, ) self.scheduled_new_reqs = new_reqs or [] self.num_scheduled_tokens = num_sched or {} + self.scheduled_spec_decode_tokens = scheduled_spec_decode_tokens or {} class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): @@ -555,17 +566,19 @@ def test_update_state_after_alloc_prefill_records_and_resets_flag(self): def test_update_state_after_alloc_decode_records_send_layerwise(self): req = MockRequest("req_u2", prompt_token_ids=list(range(10)), - kv_transfer_params={"do_remote_decode": True}) + kv_transfer_params={ + "do_remote_decode": True, + "remote_block_ids": {} + }) blocks = _MockBlocks(unhashed=[], block_ids_tuple=([7, 8, 9], )) self.scheduler.update_state_after_alloc(req, blocks, num_external_tokens=0) self.assertIn("req_u2", self.scheduler._reqs_need_send_layerwise) - total_tokens, local_block_ids, req_ref = self.scheduler._reqs_need_send_layerwise[ - "req_u2"] - self.assertEqual(total_tokens, 10) - self.assertEqual(local_block_ids, [7, 8, 9]) - self.assertIs(req_ref, req) + info = self.scheduler._reqs_need_send_layerwise["req_u2"] + self.assertEqual(info.local_block_ids, [7, 8, 9]) + self.assertIs(info.request, req) + self.assertEqual(info.remote_block_ids, []) def test_build_connector_meta_consumes_reqs_need_recv_and_clears(self): req = MockRequest("req_b1",