
[BugFix]This PR aims to fix the precision issue of the LoRA feature i…#4046

Closed
liuchenbing wants to merge 5 commits into vllm-project:main from liuchenbing:main_lora

Conversation

@liuchenbing
Contributor

@liuchenbing liuchenbing commented Nov 7, 2025

vLLM version: v0.11.0
vLLM main: vllm-project/vllm

What this PR does / why we need it?

   Fix the precision issue of the LoRA feature in vllm-ascend

Does this PR introduce any user-facing change?

How was this patch tested?

pytest tests/lora/test_llama_tp.py::test_llama_lora -s


…n vllm-ascend.

vLLM version: v0.11.0
vLLM main: vllm-project/vllm
Signed-off-by: liuchenbing <chenliumail@163.com>
@github-actions
Contributor

github-actions bot commented Nov 7, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request aims to fix a precision issue with LoRA features and to enable bfloat16 kernels more broadly. The change in vllm_ascend/lora/punica_npu.py correctly casts the input tensor x to torch.float32 in add_lora_embedding, which aligns with the kernel's expectation and should resolve the precision problem.

However, the changes in the C++ kernel files (bgmv_expand.cpp, bgmv_shrink.cpp, sgmv_expand.cpp, sgmv_shrink.cpp) introduce a critical compilation issue. While the calls to the bfloat16_t kernels are now unconditional, their definitions remain inside conditional compilation blocks (#if (__CCE_AICORE__ >= 220)). This will cause build failures on hardware with __CCE_AICORE__ < 220. Please address this by making the kernel definitions unconditional as well.

Comment on lines +359 to +361
bgmv_expand_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, indices, indicesSize, yIn, yOut, batchSize,
numTokensPerCore, maxLoRARank, outputHiddenDim,
sliceOffset, outputFullDim);
Contributor


critical

You've removed the conditional compilation directive for the bgmv_expand_bfloat16_t kernel call, making it unconditional. However, the kernel's definition using BGMV_EXPAND_TYPE_DECLARE(bfloat16_t) at line 346 is still inside an #if (__CCE_AICORE__ >= 220) block. This will cause a compilation error on platforms where __CCE_AICORE__ < 220, as the function will be called but not defined. You should also remove the conditional compilation around the kernel declaration to fix this.

Comment on lines +243 to +244
bgmv_shrink_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, indices, indicesSize, y, batchSize, numTokensPerCore,
inputHiddenDim, maxLoRARank, scale);
Contributor


critical

You've removed the conditional compilation directive for the bgmv_shrink_bfloat16_t kernel call, making it unconditional. However, the kernel's definition using BGMV_SHRINK_TYPE_DECLARE(bfloat16_t) at line 230 is still inside an #if (__CCE_AICORE__ >= 220) block. This will cause a compilation error on platforms where __CCE_AICORE__ < 220, as the function will be called but not defined. You should also remove the conditional compilation around the kernel declaration to fix this.

Comment on lines +378 to +381
sgmv_expand_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, loraIndicesSize,
seqLen, seqLenSize, yIn, yOut, batchSize,
numTokensPerCore, maxLoRARank, outputHiddenDim,
sliceOffset, outputFullDim);
Contributor


critical

You've removed the conditional compilation directive for the sgmv_expand_bfloat16_t kernel call, making it unconditional. However, the kernel's definition using SGMV_EXPAND_TYPE_DECLARE(bfloat16_t) at line 361 is still inside an #if (__CCE_AICORE__ >= 220) block. This will cause a compilation error on platforms where __CCE_AICORE__ < 220, as the function will be called but not defined. You should also remove the conditional compilation around the kernel declaration to fix this.

Comment on lines +263 to +267
sgmv_shrink_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, loraIndicesSize,
seqLen, seqLenSize,
y, batchSize,
numTokensPerCore, inputHiddenDim, maxLoRARank,
scale);
Contributor


critical

You've removed the conditional compilation directive for the sgmv_shrink_bfloat16_t kernel call, making it unconditional. However, the kernel's definition using SGMV_SHRINK_TYPE_DECLARE(bfloat16_t) at line 246 is still inside an #if (__CCE_AICORE__ >= 220) block. This will cause a compilation error on platforms where __CCE_AICORE__ < 220, as the function will be called but not defined. You should also remove the conditional compilation around the kernel declaration to fix this.

@paulyu12
Collaborator

paulyu12 commented Nov 7, 2025

This PR fixes 2 bugs:

  1. The accuracy issue hit when adding the Llama-2-7b-hf LoRA e2e test case.
  2. The LoRA custom operators do not support dtype bfloat16, which is also mentioned at [Bug]: LoRA not working in v0.11.0rc0 #3668 (comment)

@liuchenbing Could you consider fixing it according to the Gemini review comments?

@paulyu12 paulyu12 added the ready (read for review) and ready-for-test (start test by label for PR) labels Nov 12, 2025
@paulyu12
Collaborator

This PR is a duplicate of #4141. We'll concentrate on that one, and this one will be closed.

@paulyu12 paulyu12 closed this Nov 13, 2025
paulyu12 added a commit that referenced this pull request Jan 13, 2026
### What this PR does / why we need it?

This PR depends on PR
#4046, and it will only work once that PR is merged.

This PR aims to solve the issue
#3240.

The newly added Llama-2-7b-hf and Qwen3-0.6B test cases cover the
scenarios in which the LoRA weights are added to the q_proj, v_proj, k_proj,
o_proj, gate_proj, up_proj, down_proj, embed_tokens and lm_head modules.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
pytest -sv tests/e2e/singlecard/test_llama2_lora.py
pytest -sv tests/e2e/singlecard/test_qwen3_multi_loras.py


- vLLM version: v0.11.0
- vLLM main:
vllm-project/vllm@83f478b

---------

Signed-off-by: paulyu12 <507435917@qq.com>
aipaes pushed a commit to aipaes/vllm-ascend that referenced this pull request Jan 15, 2026
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Jan 31, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026

(Each of these commits carries the same message as the commit referenced above.)

Labels

ready (read for review), ready-for-test (start test by label for PR)

2 participants