
[P/D] Improve the performance of Layerwise Connector #5303

Merged
yiz-liu merged 8 commits into vllm-project:main from nwpu-zxr:layerwise_opt
Dec 31, 2025

Conversation

@nwpu-zxr
Contributor

@nwpu-zxr nwpu-zxr commented Dec 24, 2025

What this PR does / why we need it?

Improve the performance of the Layerwise Connector. The main changes are:

  1. Use event synchronization instead of stream synchronization.
  2. Access the metaserver during scheduling.
  3. Transfer the KV cache for each chunked-prefill segment.

This PR is related to [RFC]: CDCP Scheduling for Disaggregated Prefilling with KV Cache Layerwise Push Support #4842
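Point 1 can be illustrated with a CPU-side analogy. This is a minimal, purely illustrative sketch using `threading.Event` in place of device events (such as `torch.npu.Event`); none of these names come from the connector code. Event-based sync waits only until a specific piece of work is done, while stream-style sync (a thread join here) also waits for everything queued afterwards:

```python
import threading
import time

log = []
kv_ready = threading.Event()

def worker():
    log.append("kv cache copied")   # the work the consumer depends on
    kv_ready.set()                  # "record" the event
    time.sleep(0.05)                # unrelated later work on the same "stream"
    log.append("later work done")

t = threading.Thread(target=worker)
t.start()

kv_ready.wait()                     # event sync: unblocks right after the copy
assert log[0] == "kv cache copied"  # later work may still be running
t.join()                            # stream-style sync would block here instead
```

On device, the analogous change is recording an event right after the KV cache copy and synchronizing on that event, rather than draining the whole stream.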

Does this PR introduce any user-facing change?

No.

How was this patch tested?

By CI.

@github-actions
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message to match the PR description, to help reviewers and future developers understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces several performance improvements to the Layerwise Connector. The main changes include replacing stream synchronization with event synchronization for better asynchrony, accessing a metaserver during scheduling, and adjusting KV cache transfer logic. The switch to event-based synchronization is a good optimization. However, I've found some critical issues related to resource management and error handling in the new metaserver access logic that need to be addressed.

Comment on lines +673 to +687
def _access_metaserver(self, url, message):
    success = False
    retry = 0
    while retry < 3 and success is False:
        retry += 1
        try:
            self.metaserver_client.post(url, json=message)
            success = True
        except Exception as e:
            logger.error(
                f"Failed to connect to metaserver: {url}, retry {retry} time."
            )
            if retry == 3:
                raise e

Contributor


critical

The _access_metaserver method does not handle HTTP errors correctly. The httpx.Client.post method does not raise an exception for HTTP error status codes (e.g., 4xx, 5xx). The current implementation will treat such responses as successful, which can lead to silent failures and incorrect behavior. The retry logic also lacks a backoff delay, which can put unnecessary load on the metaserver during failures. The function also fails silently if all retries fail without an exception being caught on the last attempt.

    def _access_metaserver(self, url, message):
        last_exception = None
        for attempt in range(1, 4):
            try:
                response = self.metaserver_client.post(url, json=message)
                response.raise_for_status()  # Raise an exception for 4xx/5xx status codes
                return
            except Exception as e:
                last_exception = e
                logger.error(
                    f"Failed to connect to metaserver: {url}, attempt {attempt}/3. Error: {e}"
                )
                if attempt < 3:
                    time.sleep(1)  # Add a 1-second delay before retrying
        if last_exception:
            raise last_exception
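The retry-with-backoff pattern in the suggestion can be exercised in isolation. Below is a minimal sketch with a hypothetical `post_with_retry` helper and a stub client; it is not the connector's actual API, just the pattern:

```python
import time

def post_with_retry(post, url, message, attempts=3, delay=0.0):
    """Hypothetical helper: call `post` up to `attempts` times, sleeping
    `delay` seconds between failures; re-raise the last exception if
    every attempt fails."""
    last_exc = None
    for attempt in range(1, attempts + 1):
        try:
            post(url, json=message)
            return attempt                # number of attempts used
        except Exception as exc:
            last_exc = exc
            if attempt < attempts:
                time.sleep(delay)         # backoff before retrying
    raise last_exc

# Stub client that fails twice, then succeeds.
calls = []
def flaky_post(url, json):
    calls.append(url)
    if len(calls) < 3:
        raise ConnectionError("metaserver unreachable")

assert post_with_retry(flaky_post, "/meta", {"req": 1}) == 3
assert len(calls) == 3
```

Note that the suggested method also needs `import time` at module level, and a real implementation should additionally call `response.raise_for_status()` so HTTP 4xx/5xx responses are retried rather than treated as success.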

Comment on lines +516 to +519

self.executor = ThreadPoolExecutor(32)
self.metaserver_client = httpx.Client(
    limits=httpx.Limits(max_connections=100000),
    timeout=None)
Contributor


high

The ThreadPoolExecutor and httpx.Client are initialized here but never shut down or closed, which leaks resources (threads, file descriptors). Add a shutdown method to the MooncakeLayerwiseConnectorScheduler class to clean them up.

Additionally, timeout=None on the httpx.Client is risky in production: a request to the metaserver could hang indefinitely and block a thread from the pool. Set a reasonable timeout instead.

For example:

def shutdown(self):
    self.executor.shutdown(wait=True)
    self.metaserver_client.close()

This method should be called when the scheduler is no longer needed.

Suggested change

Before:
    self.executor = ThreadPoolExecutor(32)
    self.metaserver_client = httpx.Client(
        limits=httpx.Limits(max_connections=100000),
        timeout=None)

After:
    self.executor = ThreadPoolExecutor(32)
    self.metaserver_client = httpx.Client(
        limits=httpx.Limits(max_connections=100000),
        timeout=60.0)

if self.current_layer != layer_index:
    self.current_layer = layer_index
    self.model_stream.synchronize()
    reshape_cache_event.synchronize()
Collaborator


Have you encountered any hang issues during testing? #4976

Collaborator


In our self-validation we used the latest CANN 8.5 package and did not encounter any hang issues with event synchronization. We will add the relevant explanation.

Comment thread vllm_ascend/attention/mla_v1.py Outdated
prefill_slots = attn_metadata.slot_mapping[
    num_decode_tokens:num_actual_tokens]
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
attn_metadata.prefill.reshape_cache_event = torch.npu.Event()
Collaborator


Recording the event needs to be constrained so that it happens only on P (producer) nodes, not for every prefill.

    remote_host=self.side_channel_host,
    remote_port=self.side_channel_port,
)
future = self.executor.submit(
Collaborator


Move the request for P nodes forward to the schedule stage, and delete the corresponding logic in the worker.

@liziyu179 liziyu179 force-pushed the layerwise_opt branch 5 times, most recently from a9e81e5 to f9e5a71 Compare December 24, 2025 08:57
Comment thread vllm_ascend/attention/mla_v1.py Outdated
    num_decode_tokens:num_actual_tokens]
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
if self.is_kv_producer:
    attn_metadata.prefill.reshape_cache_event = torch.npu.Event()
Contributor

@wujinyuan1 wujinyuan1 Dec 24, 2025


Does the long-sequence CP mode need this extra processing as well?

@github-actions
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.


@liziyu179 liziyu179 force-pushed the layerwise_opt branch 2 times, most recently from 4f7088d to 22c306b Compare December 27, 2025 06:24

nwpu-zxr and others added 3 commits December 29, 2025 14:46
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
@liziyu179 liziyu179 force-pushed the layerwise_opt branch 2 times, most recently from b692e63 to cf72d91 Compare December 29, 2025 07:28
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
@wangxiaoteng888 wangxiaoteng888 force-pushed the layerwise_opt branch 2 times, most recently from 437dec8 to 0430743 Compare December 29, 2025 12:49
wangxiaoteng888 and others added 3 commits December 29, 2025 20:50
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
@wangxiaoteng888 wangxiaoteng888 force-pushed the layerwise_opt branch 4 times, most recently from 7442647 to 279f2f2 Compare December 30, 2025 11:33
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
@yiz-liu yiz-liu added the ready (ready for review) and ready-for-test (start test by label for PR) labels Dec 31, 2025
@yiz-liu yiz-liu merged commit 46a1614 into vllm-project:main Dec 31, 2025
62 of 66 checks passed
wjunLu pushed a commit to wjunLu/vllm-ascend that referenced this pull request Jan 4, 2026
### What this PR does / why we need it?
Improve the performance of Layerwise Connector, mainly includes the
following points:
1. Use event synchronize to replace stream synchronize.
2. Access metaserver when scheduling.
3. Transfer kvcache each Chunk prefill segmentation.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By CI.
- vLLM version: release/v0.13.0
- vLLM main:
vllm-project/vllm@5fbfa8d

---------

Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
Signed-off-by: wjunLu <wjunlu217@gmail.com>
Rozwel-dx pushed a commit to Rozwel-dx/vllm-ascend that referenced this pull request Jan 8, 2026
@nwpu-zxr nwpu-zxr deleted the layerwise_opt branch February 27, 2026 09:16
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026