[feat] cross layer kvcache #4768

Closed

HF-001 wants to merge 16 commits into vllm-project:main from HF-001:cross_layer_kvcache_dev

Conversation

@HF-001 (Contributor) commented Dec 8, 2025

What this PR does / why we need it?

This PR is for #4140; see vllm-project/vllm#27743 for reference.

Following this PR, connectors can turn on and adapt to the new layout.

This PR enables model_runner_v1 to allocate the KV cache tensors so that the KV data for all layers is contiguous per block. This can significantly speed up KV transfers (e.g. 4x), as in the case of OffloadingConnector. Currently, the new layout is disabled by default and is only enabled when using a connector that explicitly prefers it. It is also currently only supported for uniform (non-HMA) models.
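To make the layout change concrete, here is a minimal pure-Python sketch of how a stride-order permutation reorders a KV cache shape so that the block dimension becomes outermost. The dimensions and the stride order below are made up for illustration; none of these values come from the PR.

```python
def permute_shape(shape, stride_order):
    # Reorder dimensions so position i of the result takes source dimension
    # stride_order[i]; torch.Tensor.permute does the same on tensors.
    return tuple(shape[i] for i in stride_order)

# Hypothetical per-layer shape: (2 for K/V, num_blocks, block_size,
# num_heads, head_size), with a stride order moving num_blocks outermost.
shape = (2, 1000, 128, 8, 128)
stride_order = (1, 0, 2, 3, 4)
print(permute_shape(shape, stride_order))  # (1000, 2, 128, 8, 128)
```

With blocks outermost (and, in the cross-layer case, layers packed inside each block), all KV data belonging to one block occupies a single contiguous span, which is what makes one large copy per block possible.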

How was this patch tested?

```shell
export CUDA_VISIBLE_DEVICES=6
export TP=1
export MODEL_PATH=/model/Qwen3-14B
export MODEL_NAME=Qwen3-14B
export PORT=10113
#export ASCEND_LAUNCH_BLOCKING=1
#export ASCEND_SLOG_PRINT_TO_STDOUT=1

python3 -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 --port ${PORT} \
    --dtype bfloat16 \
    --model ${MODEL_PATH} \
    --no-enable-prefix-caching \
    --served-model-name ${MODEL_NAME} \
    --tensor-parallel-size ${TP} \
    --gpu-memory-utilization 0.65 \
    --max-model-len 32768 \
    --trust-remote-code \
    --disable-log-requests \
    --block-size 128 \
    --kv-transfer-config '{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"block_size": 128, "num_cpu_blocks": 1000, "spec_name":"NPUOffloadingSpec", "spec_module_path": "vllm_ascend.kv_offload.npu"}}'
```

Performance testing on Qwen3-14B; the result is:

[Screenshot: performance test results, 2025-12-08]

Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a cross-layer KV cache layout to improve performance for KV transfers, especially for offloading. The changes are mainly in vllm_ascend/worker/model_runner_v1.py to conditionally allocate a contiguous KV cache tensor for all layers. The attention backends are also updated to support this new layout.

The overall approach is sound, but I've found two critical issues in the implementation of allocate_uniform_kv_caches that could lead to incorrect behavior or runtime errors. One issue is related to the assumption of equal K and V cache sizes for MLA models, and the other is an incorrect permutation logic for non-MLA models. Please see the detailed comments for suggestions on how to fix them.

Comment on lines +4745 to +4760
else :
    new_kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
    new_kv_cache_shape[-1] = new_kv_cache_shape[-1]//2
    logger.info("Allocating a cross layer KV cache of shape %s", new_kv_cache_shape)

    # allocate one contiguous buffer for all layers
    cross_layers_k_cache = (
        torch.zeros(total_size//2, dtype=torch.int8, device=device)
        .view(kv_cache_spec.dtype)
        .view(new_kv_cache_shape)
    )
    cross_layers_v_cache = (
        torch.zeros(total_size//2, dtype=torch.int8, device=device)
        .view(kv_cache_spec.dtype)
        .view(new_kv_cache_shape)
    )
critical

The logic for calculating new_kv_cache_shape for MLA models and the subsequent allocation of cross_layers_k_cache and cross_layers_v_cache assumes that the key and value caches have the same size. Specifically, it divides head_size by 2 and uses this for both K and V caches. However, for MLA models, K and V caches can have different sizes, determined by kv_lora_rank and qk_rope_head_dim respectively. This will lead to incorrect tensor shapes and memory corruption if kv_lora_rank != qk_rope_head_dim. The allocation size total_size // 2 is also based on this incorrect assumption.

The K and V caches should be handled with their respective shapes and sizes. You should calculate their sizes and shapes separately based on kv_lora_rank and qk_rope_head_dim.
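To illustrate the reviewer's point with made-up MLA dimensions (the values below are illustrative only and are not taken from the PR or any specific model):

```python
# Hypothetical MLA dimensions; illustrative only.
kv_lora_rank = 512      # determines the entry width of one cache component
qk_rope_head_dim = 64   # determines the entry width of the other
num_blocks, block_size, elem_bytes = 1000, 128, 2  # bf16 elements

# Sizing each cache from its own dimension:
k_bytes = num_blocks * block_size * kv_lora_rank * elem_bytes
v_bytes = num_blocks * block_size * qk_rope_head_dim * elem_bytes

# Splitting the combined size in half instead would only be correct when
# kv_lora_rank == qk_rope_head_dim, which need not hold for MLA models.
half = (k_bytes + v_bytes) // 2
assert half != k_bytes and half != v_bytes
```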

Comment on lines +4764 to +4768
inv_order = [
    kv_cache_stride_order.index(i)
    for i in range(len(kv_cache_stride_order))
    if kv_cache_shape[kv_cache_stride_order[i]] != 2
]
if len(new_kv_cache_shape) != len(kv_cache_shape):
    inv_order = [i - 1 for i in inv_order]
critical

The calculation of inv_order for non-MLA models is incorrect. The list comprehension logic and the subsequent adjustment [i - 1 for i in inv_order] will produce a wrong permutation, which can contain negative indices. This will lead to incorrect memory access when permute(*inv_order) is called, causing data corruption or runtime errors. The logic needs to be revised to correctly compute the inverse permutation that brings the num_layers dimension to the front of the tensor.
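For reference, the standard way to compute the inverse of a permutation (a generic sketch, not the PR's code) is to scatter each destination index back to its source position:

```python
def inverse_permutation(order):
    # order[i] = src dimension that ends up at destination i.
    # The inverse sends that src dimension back to position i.
    inv = [0] * len(order)
    for dst, src in enumerate(order):
        inv[src] = dst
    return inv

order = [1, 0, 2, 3, 4]
inv = inverse_permutation(order)

# Applying `order` and then `inv` restores the original ordering.
shape = (2, 1000, 128, 8, 128)
permuted = tuple(shape[i] for i in order)
restored = tuple(permuted[i] for i in inv)
assert restored == shape
```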

@MengqingCao MengqingCao self-assigned this Dec 8, 2025
@github-actions (bot) commented Dec 8, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

01267596 added 4 commits December 8, 2025 03:41
Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
@HF-001 (Author) commented Dec 10, 2025

@Sparkheart @MengqingCao Hi, this PR is ready. Could you help review?

@github-actions (bot) commented
This pull request has conflicts, please resolve those before we can evaluate the pull request.

Signed-off-by: kx <1670186653@qq.com>
HF-001 and others added 3 commits December 11, 2025 09:40
Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
@zzzzwwjj (Collaborator) commented Dec 12, 2025

This PR cannot be merged yet. We need to align on the impact of KV cache layout on functionality and the proposed solutions before merging.
Functions that may be affected:

  1. KV Connector related, such as Prefill-Decode Disaggregation, KV Pool, KV offload. This feature will have an impact on the KV transmission function.
  2. KV calculation related, such as reshape_and_cache, paged_cache_load, and various attention ops. All of these ops need to support strided tensors.

cc @wangxiyuan @weijinqian0 @MengqingCao

Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
@HF-001 (Author) commented Dec 12, 2025

> This PR cannot be merged yet. We need to align on the impact of KV cache layout on functionality and the proposed solutions before merging. Functions that may be affected:
>
>   1. KV Connector related, such as Prefill-Decode Disaggregation, KV Pool, KV offload. This feature will have an impact on the KV transmission function.
>   2. KV calculation related, such as reshape_and_cache, paged_cache_load, and various attention ops. All of these ops need to support strided tensors.
>
> cc @wangxiyuan @weijinqian0 @MengqingCao

@zzzzwwjj Hi,
1. This PR only takes effect when using an offloading connector with preempt_cross_layer_blocks set to true, and does not affect other scenarios. PD Disaggregation scenarios would require a simple adaptation.
2. stride_order only has an impact when the offloading connector and preempt_cross_layer_blocks are enabled, and does not affect other scenarios. I think stride_order will not impact reshape_and_cache or paged_cache_load.

cc @wangxiyuan @weijinqian0 @MengqingCao

For ease of understanding, the following is the design diagram of this PR. In fact, the cross_layers_k/v_cache only adjusts the layout through stride_order.

[Screenshot: design diagram, 2025-12-12]

Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
@github-actions (bot) commented
This pull request has conflicts, please resolve those before we can evaluate the pull request.

Signed-off-by: kx <1670186653@qq.com>
Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
@MengqingCao (Collaborator) commented

> This PR cannot be merged yet. We need to align on the impact of KV cache layout on functionality and the proposed solutions before merging. Functions that may be affected:
>
>   1. KV Connector related, such as Prefill-Decode Disaggregation, KV Pool, KV offload. This feature will have an impact on the KV transmission function.
>   2. KV calculation related, such as reshape_and_cache, paged_cache_load, and various attention ops. All of these ops need to support strided tensors.
>
> cc @wangxiyuan @weijinqian0 @MengqingCao
>
> @zzzzwwjj hi, 1. this PR will only take effect when using an offloading connector and preempt_cross_layer_blocks is true, and will not affect other scenarios. If PD Disaggregation scenarios need to be used, simple adaptation is required. 2. Stride_order will only have an impact when the offloading connector and preempt_cross_layer_blocks are true, and will not affect other scenarios. i think Stride_order will not impact reshape_and_cache, paged_cache_load.

Actually, the op reshape_and_cache doesn't support non-contiguous tensors now; I'm not sure how this PR works when both the offloading connector and preempt_cross_layer_blocks are enabled.

@HF-001 (Author) commented Dec 15, 2025

> This PR cannot be merged yet. We need to align on the impact of KV cache layout on functionality and the proposed solutions before merging. Functions that may be affected:
>
>   1. KV Connector related, such as Prefill-Decode Disaggregation, KV Pool, KV offload. This feature will have an impact on the KV transmission function.
>   2. KV calculation related, such as reshape_and_cache, paged_cache_load, and various attention ops. All of these ops need to support strided tensors.
>
> cc @wangxiyuan @weijinqian0 @MengqingCao
>
> @zzzzwwjj hi, 1. this PR will only take effect when using an offloading connector and preempt_cross_layer_blocks is true, and will not affect other scenarios. If PD Disaggregation scenarios need to be used, simple adaptation is required. 2. Stride_order will only have an impact when the offloading connector and preempt_cross_layer_blocks are true, and will not affect other scenarios. i think Stride_order will not impact reshape_and_cache, paged_cache_load.
>
> Actually the op reshape_and_cache doesn't support non-contigous tensor now, not sure how this pr works now when both offloading connector and preempt_cross_layer_blocks are enabled

@MengqingCao In this PR, we use cross_layers_k_cache and cross_layers_v_cache (contiguous tensors) instead of kv_caches (which are not used), so reshape_and_cache is supported. The shape of cross_layers_k_cache/cross_layers_v_cache is (num_blocks, num_layers, block_size, num_heads, head_size). During each transmission, the transmitted block is (num_layers, block_size, num_heads, head_size) instead of the previous (block_size, num_heads, head_size).

In this PR, the number of bytes per transmitted block is num_layers times larger than before, similar to batched transmission, which can greatly improve transfer efficiency.

maybe you can refer to: vllm-project/vllm#27742
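Back-of-the-envelope arithmetic for the batching effect described above, using hypothetical Qwen3-14B-like dimensions (illustrative values, not measurements from this PR):

```python
# Hypothetical dimensions; illustrative only.
num_layers, block_size, num_kv_heads, head_size, elem_bytes = 40, 128, 8, 128, 2

# Per-layer layout moves one (block_size, num_kv_heads, head_size) slab per copy.
per_layer_block = block_size * num_kv_heads * head_size * elem_bytes

# Cross-layer layout moves all layers of a block in one contiguous copy.
cross_layer_block = num_layers * per_layer_block

# One large transfer replaces num_layers small ones.
assert cross_layer_block // per_layer_block == num_layers
print(per_layer_block, cross_layer_block)  # 262144 10485760
```

Fewer, larger transfers amortize per-transfer setup cost, which is where the reported speedup comes from.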

@LCAIZJ (Collaborator) commented Dec 15, 2025

The PD connector also requires corresponding modifications to adapt to the new layout.

@HF-001 (Author) commented Dec 15, 2025

> PD connector need also requires corresponding modifications to adapt to the new layout.

@LCAIZJ Do you mean I need to add PD disaggregation adaptation code in this PR? Although it can be implemented, it would compromise consistency with vLLM. Currently, vLLM has only implemented this for the KV cache offloading scenario.

@HF-001 (Author) commented Dec 15, 2025

> Actually the op reshape_and_cache doesn't support non-contigous tensor now, not sure how this pr works now when both offloading connector and preempt_cross_layer_blocks are enabled

@MengqingCao Do you mean torch_npu._npu_reshape_and_cache()? I have tested the entire process and it works normally.

@MengqingCao (Collaborator) commented

> Actually the op reshape_and_cache doesn't support non-contigous tensor now, not sure how this pr works now when both offloading connector and preempt_cross_layer_blocks are enabled
>
> @MengqingCao do you mean torch_npu._npu_reshape_and_cache()? I have tested the entire process and it can work normally.

Exactly, I'll debug locally with this pr later to check some details

@HF-001 (Author) commented Dec 15, 2025

> Exactly, I'll debug locally with this pr later to check some details

@MengqingCao You are right. torch_npu._npu_reshape_and_cache() currently cannot handle kv_caches that are non-contiguous. Could torch_npu._npu_reshape_and_cache() be optimized to support non-contiguous tensors, like reshape_and_cache() in vLLM?
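The non-contiguity being discussed here can be seen from strides alone. A pure-Python sketch (made-up dimensions) mirroring row-major stride computation: a single layer's view of the cross-layer (num_blocks, num_layers, block_size, num_heads, head_size) buffer keeps the parent's strides, so its elements are no longer back-to-back in memory:

```python
def contiguous_strides(shape):
    # Row-major (C-order) strides in elements, as torch/numpy compute them.
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))

def is_contiguous(shape, strides):
    return strides == contiguous_strides(shape)

# Cross-layer buffer: (num_blocks, num_layers, block_size, num_heads, head_size)
full = (1000, 40, 128, 8, 128)
full_strides = contiguous_strides(full)

# A per-layer view fixes the num_layers axis: its shape drops that dimension
# but it inherits the parent's strides for the remaining axes.
layer_shape = (1000, 128, 8, 128)
layer_strides = (full_strides[0], full_strides[2], full_strides[3], full_strides[4])
assert not is_contiguous(layer_shape, layer_strides)
```

This is why ops that require contiguous inputs would need stride support (or an explicit copy) to consume per-layer views of the cross-layer buffer.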

@HF-001 (Author) commented Dec 16, 2025

@MengqingCao @zzzzwwjj May I ask if there are any other operators that need to be optimized besides torch_npu._npu_reshape_and_cache()? If it's just torch_npu._npu_reshape_and_cache(), perhaps I can refer to the reshape_and_cache() operator in vLLM and optimize it for Ascend.

wangxiyuan added a commit that referenced this pull request Dec 18, 2025
I'd like to nominate @zzzzwwjj @realliujiaxu @LCAIZJ to join vLLM Ascend
committer team.

@zzzzwwjj
---
- Review Quality‌:
He has completed 80+ reviews since April 2025, including
#3232 (comment),
#4822 (comment), and
#4768 (comment),
all high-quality reviews.

- Sustained Contributions
15+ valuable bug fixes and refactors.

https://github.com/vllm-project/vllm-ascend/pulls?q=is%3Apr+author%3Azzzzwwjj+is%3Aclosed+review%3Aapproved
Continuous optimization of code architecture

https://github.com/vllm-project/vllm-ascend/pulls?q=author%3Azzzzwwjj+is%3Amerged

- Quality Contribution‌:
#1229
#1979
#4359
#4878

- Community Involvement‌: 
He led #1147, the first refactor of AscendFusedMoE.
He shared topics about large-scale distributed inference and
reinforcement learning on vLLM-Ascend meetup on August 2nd.

@realliujiaxu
---
- Review Quality‌:
He has completed about [40+
reviews](https://github.com/vllm-project/vllm-ascend/pulls?q=is%3Apr+commenter%3Arealliujiaxu+-author%3Arealliujiaxu+)
since September, including
#4868 (comment),
#2275 (comment).

- Sustained Contributions
He has completed [17
commits](https://github.com/vllm-project/vllm-ascend/pulls?q=is%3Apr+author%3Arealliujiaxu+is%3Amerged),
continuously optimizing the performance of the MoE model.

- Quality Contribution‌:

Contributed the Flash Comm1 feature to the community, supporting both
eager and aclgraph execution modes, while compatible with multiple MoE
models including DeepSeek and GLM4.5.
  - #3334
  - #3420
  - #3015
  
  co-author:
  - #3495
  - #4868

- Community Involvement‌: 
1. Completed two major refactors, enabling vllm-ascend to evolve more
rapidly and robustly: [Linear
module](#2867) and
[rejection
sampler](#4975)
2. [fixed 8
bugs](https://github.com/vllm-project/vllm-ascend/pulls?q=is%3Apr+author%3Arealliujiaxu+is%3Amerged+bugfix+)
in graph mode, spec decoding and async scheduling.

@LCAIZJ
---
- Review Quality‌: He's been the go-to reviewer for virtually all PD
disaggregation and KV Pool related PRs, having completed [30+
reviews](https://github.com/vllm-project/vllm-ascend/pulls?q=is%3Apr+commenter%3ALCAIZJ+is%3Aopen+-author%3ALCAIZJ+)
since May 2025. Notable examples include
[discussion_r2553887360](#4345 (comment)),
[issuecomment-3540994801](#4161 (comment)),
and
[discussion_r2492593988](#3981 (comment)),
all demonstrating thorough and insightful feedback.
- Sustained and Quality Contributions: His contributions reflect a
strong grasp of both ‌vLLM‌ and ‌vLLM Ascend‌ codebases, particularly in
prefill-decode disaggregation and KV pool areas ([7 PRs
merged](https://github.com/vllm-project/vllm-ascend/pulls?q=is%3Apr+author%3ALCAIZJ+is%3Amerged+)).
Prefill-Decode Disaggregation: Delivered KV transfer functionality using
Mooncake TransferEngine and enabled layerwise KV transfer
#1568
#2602
KV Pool: Developed the foundational KV Pool infrastructure and migrated
it to the latest ADXL stack
#2913
#3350
- Quality Contribution‌:
#1568
#2602
#2913
#3350
- Community Involvement‌: 
He actively responds to [community
issues](https://github.com/vllm-project/vllm-ascend/issues?q=is%3Aissue%20commenter%3ALCAIZJ%20is%3Aopen%20-author%3ALCAIZJ),
continuously monitors functionality and accuracy issues related to PD
disaggregation and KV Pool, and proactively delivers [bug
fixes](https://github.com/vllm-project/vllm-ascend/pulls?q=is%3Apr+author%3ALCAIZJ+is%3Amerged+bugfix).
- vLLM version: v0.12.0
- vLLM main:
vllm-project/vllm@ad32e3e

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
@github-actions (bot) commented
This pull request has conflicts, please resolve those before we can evaluate the pull request.

@MengqingCao (Collaborator) commented

> @MengqingCao @zzzzwwjj May I ask if there are any other operators that need to be optimized besides the torch_npu._npu_reshape_and_cache() operator? If it's just torch_npu._npu_reshape_and_cache(), perhaps I can refer to the reshape_and_cache() operator in VLLM and optimize it in Ascend.

Sorry for the delay; there do exist other operators, mainly including all the attention ops. Thus I don't recommend you do this now.

@HF-001 (Author) commented Dec 19, 2025

> > @MengqingCao @zzzzwwjj May I ask if there are any other operators that need to be optimized besides the torch_npu._npu_reshape_and_cache() operator? If it's just torch_npu._npu_reshape_and_cache(), perhaps I can refer to the reshape_and_cache() operator in VLLM and optimize it in Ascend.
>
> Sorry for the delay, there does exsits other operators, mainly including all the attention ops. Thus I don't recomand you to do this now

Thank you.

@HF-001 HF-001 closed this Dec 19, 2025
chenaoxuan pushed a commit to chenaoxuan/vllm-ascend that referenced this pull request Dec 20, 2025
…t#5152)
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
…t#5152)
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
…t#5152)


8 participants