
[Bug] Fix compile error for swap_blocks_batch in CUDA 13 #38915

Merged
simon-mo merged 2 commits into main from wentao-fix-compile-error on Apr 3, 2026


@yewentao256

Member

Purpose

Originally, the build failed with:

[1/3] Building CUDA object CMakeFiles/_C.dir/csrc/cache_kernels.cu.o
FAILED: [code=255] CMakeFiles/_C.dir/csrc/cache_kernels.cu.o 
ccache /usr/local/cuda-13.0/bin/nvcc -forward-unknown-to-host-compiler -DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1 -DPy_LIMITED_API=3 -DTORCH_EXTENSION_NAME=_C -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_DISTRIBUTED -DUSE_NVSHMEM -DUSE_RPC -DUSE_TENSORPIPE -D_C_EXPORTS -I/home/yewentao256/vllm-source/csrc -I/home/yewentao256/vllm-source/cmake-build-release/_deps/cutlass-src/include -I/home/yewentao256/vllm-source/cmake-build-release/_deps/cutlass-src/tools/util/include -isystem /home/yewentao256/.local/share/uv/python/cpython-3.12.13-linux-x86_64-gnu/include/python3.12 -isystem /home/yewentao256/.venv/lib/python3.12/site-packages/torch/include -isystem /home/yewentao256/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda-13.0/include -DONNX_NAMESPACE=onnx_c2 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -O3 -DNDEBUG -std=c++17 -Xcompiler=-fPIC --expt-relaxed-constexpr -DENABLE_FP8 --threads=8 --compress-mode=size -gencode arch=compute_90,code=sm_90 -MD -MT CMakeFiles/_C.dir/csrc/cache_kernels.cu.o -MF CMakeFiles/_C.dir/csrc/cache_kernels.cu.o.d -x cu -c /home/yewentao256/vllm-source/csrc/cache_kernels.cu -o CMakeFiles/_C.dir/csrc/cache_kernels.cu.o
/home/yewentao256/vllm-source/csrc/cache_kernels.cu(115): error: argument of type "size_t *" (aka "unsigned long *") is incompatible with parameter of type "CUstream" (aka "CUstream_st *")
        static_cast<size_t>(n), &attr, &attrs_idx, 1, &fail_idx,
                                                      ^

/home/yewentao256/vllm-source/csrc/cache_kernels.cu(116): error: too many arguments in function call
        static_cast<CUstream>(stream));
        ^

2 errors detected in the compilation of "/home/yewentao256/vllm-source/csrc/cache_kernels.cu".
ninja: build stopped: subcommand failed.

This is now fixed.

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256 added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Apr 3, 2026
@mergify (bot) added the nvidia and bug (Something isn't working) labels on Apr 3, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor


Code Review

This pull request introduces conditional compilation to handle an API change in cuMemcpyBatchAsync for CUDA 13.0, which removed the fail_idx parameter. The code now provides a specific implementation for CUDA 13.0 while maintaining the previous logic for older versions (12.8+). I have no feedback to provide.
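As a rough illustration of the version guard described in this summary, the sketch below gates a call site on an API version macro. It is not the code from this PR: `SKETCH_API_VERSION`, `sketch_memcpy_batch`, and `copy_blocks` are hypothetical stand-ins so that the example builds without the CUDA headers. In a real CUDA build the guard would presumably test `CUDA_VERSION` (13000 for CUDA 13.0) around the `cuMemcpyBatchAsync` call.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

// Hypothetical stand-in for the toolkit's version macro, so that this
// sketch builds without the CUDA headers.
#ifndef SKETCH_API_VERSION
#define SKETCH_API_VERSION 13000
#endif

#if SKETCH_API_VERSION >= 13000
// Newer shape: no fail_idx out-parameter.
int sketch_memcpy_batch(void** dsts, void** srcs, const std::size_t* sizes,
                        std::size_t count) {
  for (std::size_t i = 0; i < count; ++i) {
    std::memcpy(dsts[i], srcs[i], sizes[i]);
  }
  return 0;
}
#else
// Older shape: an extra fail_idx out-parameter at the end of the list.
int sketch_memcpy_batch(void** dsts, void** srcs, const std::size_t* sizes,
                        std::size_t count, std::size_t* fail_idx) {
  for (std::size_t i = 0; i < count; ++i) {
    std::memcpy(dsts[i], srcs[i], sizes[i]);
  }
  *fail_idx = count;  // convention of this sketch only: count means "no failure"
  return 0;
}
#endif

// Callers use one wrapper; only this function knows about the two shapes.
int copy_blocks(void** dsts, void** srcs, const std::size_t* sizes,
                std::size_t count) {
  assert(dsts != nullptr && srcs != nullptr && sizes != nullptr);
#if SKETCH_API_VERSION >= 13000
  return sketch_memcpy_batch(dsts, srcs, sizes, count);
#else
  std::size_t fail_idx = 0;
  return sketch_memcpy_batch(dsts, srcs, sizes, count, &fail_idx);
#endif
}
```

Keeping the `#if` inside a single wrapper means the rest of the file is unaffected when the underlying signature changes again.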

Comment thread csrc/cache_kernels.cu Outdated
CUresult result = cuMemcpyBatchAsync(
reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(dst_data)),
reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(src_data)),
reinterpret_cast<size_t*>(const_cast<int64_t*>(size_data)),

Member


I realize this was there before, but we should not need to const_cast these. Perhaps we should remove the constness of dst_data in the declaration above.

Member Author


Nice catch, fixed, thanks!
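The suggestion in this thread can be illustrated with a small sketch, under the assumption that the pointer is obtained from a buffer the caller is allowed to treat as mutable. The names `total_bytes`, `total_with_const_cast`, and `total_without_const_cast` are hypothetical and only model the cast pattern; they are not taken from csrc/cache_kernels.cu. The sketch also assumes a platform where `size_t` is 64 bits wide.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

static_assert(sizeof(std::int64_t) == sizeof(std::size_t),
              "sketch assumes a 64-bit size_t");

// Hypothetical consumer that, like a batched copy API, takes a non-const pointer.
std::size_t total_bytes(std::size_t* sizes, std::size_t count) {
  std::size_t total = 0;
  for (std::size_t i = 0; i < count; ++i) total += sizes[i];
  return total;
}

// Before: the pointer is declared const, so the call site needs const_cast.
std::size_t total_with_const_cast(const std::vector<std::int64_t>& v) {
  const std::int64_t* size_data = v.data();
  return total_bytes(
      reinterpret_cast<std::size_t*>(const_cast<std::int64_t*>(size_data)),
      v.size());
}

// After: the pointer is declared non-const and only reinterpret_cast remains.
std::size_t total_without_const_cast(std::vector<std::int64_t>& v) {
  std::int64_t* size_data = v.data();
  return total_bytes(reinterpret_cast<std::size_t*>(size_data), v.size());
}
```

Both functions return the same result; the second simply states up front that the buffer is mutable instead of casting constness away at the call site.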

@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA Apr 3, 2026
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@simon-mo simon-mo merged commit 062f1a2 into main Apr 3, 2026
141 of 142 checks passed
@simon-mo simon-mo deleted the wentao-fix-compile-error branch April 3, 2026 23:56
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA Apr 3, 2026
HenryTangDev pushed a commit to HenryTangMain/vllm that referenced this pull request Apr 6, 2026
puririshi98 pushed a commit to puririshi98/vllm that referenced this pull request Apr 7, 2026
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
wangxiyuan added a commit to vllm-project/vllm-ascend that referenced this pull request Apr 21, 2026
### What this PR does / why we need it?
Following vllm-project/vllm#38460 and vllm-project/vllm#38915,
CANN 8.5.0+ uses aclrtMemcpyBatchAsync and older CANN versions use
aclrtMemcpyAsync for KV cache offloading.

The appropriate transfer function is selected automatically at compile
time based on the CANN environment. It can also be chosen manually
through an environment variable:
1. Batch memcpy (requires CANN ≥ 8.5):
   export VLLM_ASCEND_ENABLE_BATCH_MEMCPY=1 && pip install -e .
2. Normal memcpy:
   export VLLM_ASCEND_ENABLE_BATCH_MEMCPY=0 && pip install -e .

### How was this patch tested?

test results:
main:    TTFT 307 ms,    TPOT 49.96 ms
this PR: TTFT 272.82 ms, TPOT 41.04 ms

model script:
export TP=1
export MODEL_PATH=/nas/disk1/Qwen3-14B
export MODEL_NAME=Qwen3-14B
export PORT=10113
export CUDA_VISIBLE_DEVICES=3
export ASCEND_RT_VISIBLE_DEVICES=3
python3 -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 --port ${PORT} \
    --dtype bfloat16 \
    --model ${MODEL_PATH} \
    --served-model-name ${MODEL_NAME} \
    --tensor-parallel-size ${TP} \
    --gpu-memory-utilization 0.7 \
    --no-enable-prefix-caching \
    --max-model-len 32768 \
    --trust-remote-code \
    --block-size 128 \
    --kv-transfer-config '{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"block_size": 128, "num_cpu_blocks": 1000, "spec_name":"NPUOffloadingSpec", "spec_module_path": "vllm_ascend.kv_offload.npu"}}'

test script:
export MODEL_NAME=/nas/disk1/Qwen3-14B
python /model/xk/vllm/benchmarks/multi_turn/benchmark_serving_multi_turn.py \
    --url http://127.0.0.1:10113 \
    --model $MODEL_NAME \
    --served-model-name Qwen3-14B \
    --seed 1234 \
    --input-file /model/xk/vllm/benchmarks/multi_turn/generate_multi_turn.json \
    --num-clients 8 --max-active-conversations 24



- vLLM version: v0.18.0
- vLLM main: vllm-project/vllm@35141a7

---------

Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
Signed-off-by: HF-001 <1670186653@qq.com>
Signed-off-by: kx <1670186653@qq.com>
Co-authored-by: 01267596 <xiongkai123@cmbchina.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
weijinqian0 pushed a commit to weijinqian0/vllm-ascend that referenced this pull request Apr 21, 2026
…-project#7819)

anning-2026 pushed a commit to anning-2026/vllm-ascend that referenced this pull request Apr 21, 2026
…-project#7819)

guxin108 pushed a commit to guxin108/vllm-ascend that referenced this pull request Apr 24, 2026
…-project#7819)

Signed-off-by: guxin108 <1252896542@qq.com>
zouyida2052 pushed a commit to zouyida2052/vllm-ascend that referenced this pull request Apr 28, 2026
…-project#7819)

Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
yangzhe-2026 pushed a commit to yangzhe-2026/vllm-ascend that referenced this pull request May 6, 2026
…-project#7819)

PiratePai pushed a commit to PiratePai/vllm-ascend that referenced this pull request May 7, 2026
…-project#7819)

Signed-off-by: PiratePai <416932041@qq.com>
mystous pushed a commit to mystous/vllm_hybrid that referenced this pull request May 10, 2026
yangzhe-2026 pushed a commit to yangzhe-2026/vllm-ascend that referenced this pull request May 10, 2026
…-project#7819)

Signed-off-by: yangzhe-2026 <yangzhe@isrc.iscas.ac.cn>
ZhuQi-seu pushed a commit to ZhuQi-seu/vllm-ascend that referenced this pull request May 12, 2026
…-project#7819)

Signed-off-by: ZhuQi-seu <zhuqi12@huawei.com>
nanxingMy pushed a commit to nanxingMy/vllm-ascend that referenced this pull request May 15, 2026
…-project#7819)

Signed-off-by: nanxing <1014662416@qq.com>
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
jhu960213 pushed a commit to jhu960213/vllm that referenced this pull request May 20, 2026
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
…ect#38915)

Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>

Labels

bug Something isn't working nvidia ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done


3 participants