From 7fa47fc39ed67c8cc2ff66a81ac1c1c9007f178b Mon Sep 17 00:00:00 2001 From: paulyu12 <507435917@qq.com> Date: Mon, 9 Feb 2026 18:59:24 -0800 Subject: [PATCH 1/9] [Bugfix][LoRA] Fix issue when enable LoRA + tp2 + fully_sharded_loras Signed-off-by: paulyu12 <507435917@qq.com> --- .../2-cards/test_llama32_lora_tp2.py | 35 ++++++++ vllm_ascend/lora/punica_npu.py | 11 +-- vllm_ascend/lora/utils.py | 80 ++++++++++++++++++- 3 files changed, 117 insertions(+), 9 deletions(-) create mode 100755 tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py mode change 100644 => 100755 vllm_ascend/lora/punica_npu.py mode change 100644 => 100755 vllm_ascend/lora/utils.py diff --git a/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py b/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py new file mode 100755 index 00000000000..614e80204c3 --- /dev/null +++ b/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +import vllm +import vllm.config +from vllm.lora.request import LoRARequest +from unittest.mock import patch + +from tests.e2e.conftest import VllmRunner, wait_until_npu_memory_free +from tests.e2e.singlecard.test_llama32_lora import PROMPT_TEMPLATE, EXPECTED_LORA_OUTPUT, EXPECTED_BASE_MODEL_OUTPUT, do_sample, generate_and_test +from vllm_ascend.utils import enable_custom_op + +enable_custom_op() + +# For hk region, we need to use the model from hf to avoid the network issue +MODEL_PATH = "meta-llama/Llama-3.2-3B-Instruct" + + +@patch.dict("os.environ", {"VLLM_USE_MODELSCOPE": "False"}) +@pytest.mark.parametrize("fully_sharded_loras", [False, True]) +@wait_until_npu_memory_free() +def test_llama_lora_tp2(llama32_lora_files, fully_sharded_loras): + with VllmRunner( + MODEL_PATH, + enable_lora=True, + # also test odd max_num_seqs + max_num_seqs=7, + max_model_len=1024, + max_loras=4, + tensor_parallel_size=2, + fully_sharded_loras=fully_sharded_loras, + ) as vllm_model: + llm = vllm_model.model + generate_and_test(llm, llama32_lora_files) diff --git a/vllm_ascend/lora/punica_npu.py b/vllm_ascend/lora/punica_npu.py old mode 100644 new mode 100755 index 885c0765705..1ae9ac97d5d --- a/vllm_ascend/lora/punica_npu.py +++ b/vllm_ascend/lora/punica_npu.py @@ -205,7 +205,6 @@ def add_expand( y: torch.Tensor, x: tuple[torch.Tensor, ...] | torch.Tensor, lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: tuple[torch.Tensor, ...] | None, output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, @@ -217,24 +216,20 @@ def add_expand( Semantics: for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] offset += slice Args: y (torch.Tensor): Output tensor. x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): - bias's weight output_slices (Tuple[int, ...]): Every slice's size + offset_start (int): The starting position of y, defaults to 0 add_inputs (bool): Defaults to True. """ y_org = y y = y.view(-1, y.shape[-1]) offset_left = offset_start - if lora_bias_stacked is not None: - self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked) for slice_idx in range(len(lora_b_stacked)): self._apply_expand( y, @@ -313,7 +308,7 @@ def add_lora_linear( torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) for _ in range(len(output_slices)) ) self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) - self.add_expand(y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs) + self.add_expand(y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs) def add_lora_logits( self, diff --git a/vllm_ascend/lora/utils.py b/vllm_ascend/lora/utils.py old mode 100644 new mode 100755 index a0178560303..cb99a790ede --- a/vllm_ascend/lora/utils.py +++ b/vllm_ascend/lora/utils.py @@ -9,8 +9,13 @@ QKVParallelLinearWithLoRA, RowParallelLinearWithLoRA, VocabParallelEmbeddingWithLoRA, + ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLoRA, + QKVParallelLinearWithShardedLoRA, + RowParallelLinearWithShardedLoRA ) -from vllm.lora.layers.utils import _not_fully_sharded_can_replace +from vllm.lora.layers.utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace from vllm_ascend.ops.linear import ( AscendColumnParallelLinear, @@ -23,6 +28,7 @@ class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): @classmethod + @_not_fully_sharded_can_replace def can_replace_layer( cls, source_layer: nn.Module, @@ -35,6 +41,7 @@ def can_replace_layer( class AscendMergedColumnParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): @classmethod + @_not_fully_sharded_can_replace def can_replace_layer( cls, source_layer: nn.Module, @@ -47,6 +54,7 @@ def can_replace_layer( class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA): @classmethod + @_not_fully_sharded_can_replace def can_replace_layer( cls, source_layer: nn.Module, @@ -95,6 +103,71 @@ def can_replace_layer( return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 3 +class AscendColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithShardedLoRA): + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: PretrainedConfig | None = None, + ) -> bool: + return type(source_layer) is AscendColumnParallelLinear + + +class AscendMergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithShardedLoRA): + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: PretrainedConfig | None = None, + ) -> bool: + return type(source_layer) is AscendMergedColumnParallelLinear + + +class AscendMergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithShardedLoRA): + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: PretrainedConfig | None = None, + ) -> bool: + return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 3 + + +class AscendQKVParallelLinearWithShardedLoRA(QKVParallelLinearWithShardedLoRA): + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: PretrainedConfig | None = None, + ) -> bool: + return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 1 + + +class AscendRowParallelLinearWithShardedLoRA(RowParallelLinearWithShardedLoRA): + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: PretrainedConfig | None = None, + ) -> bool: + return type(source_layer) is AscendRowParallelLinear + + def refresh_all_lora_classes(): vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA) vllm.lora.utils._all_lora_classes.add(AscendMergedColumnParallelLinearWithLoRA) @@ -102,3 +175,8 @@ def refresh_all_lora_classes(): vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA) vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA) vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithLoRA) + vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithShardedLoRA) + vllm.lora.utils._all_lora_classes.add(AscendMergedColumnParallelLinearWithShardedLoRA) + vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithShardedLoRA) + vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithShardedLoRA) + vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithShardedLoRA) From 6415dce5ba6e5f16dbabc1eed2d00adf25c4366e Mon Sep 17 00:00:00 2001 From: paulyu12 <507435917@qq.com> Date: Mon, 9 Feb 2026 19:54:31 -0800 Subject: [PATCH 2/9] [Bugfix][LoRA] lint Signed-off-by: paulyu12 <507435917@qq.com> --- vllm_ascend/lora/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm_ascend/lora/utils.py b/vllm_ascend/lora/utils.py index cb99a790ede..341a3ab4ba1 100755 --- a/vllm_ascend/lora/utils.py +++ b/vllm_ascend/lora/utils.py @@ -4,16 +4,16 @@ from vllm.config import LoRAConfig from vllm.lora.layers import ( ColumnParallelLinearWithLoRA, + ColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithLoRA, + QKVParallelLinearWithShardedLoRA, RowParallelLinearWithLoRA, + RowParallelLinearWithShardedLoRA, VocabParallelEmbeddingWithLoRA, - ColumnParallelLinearWithShardedLoRA, - MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLoRA, - QKVParallelLinearWithShardedLoRA, - RowParallelLinearWithShardedLoRA ) from vllm.lora.layers.utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace From d185abad6bd9b29134f97a733a4102627347c836 Mon Sep 17 00:00:00 2001 From: paulyu12 <507435917@qq.com> Date: Mon, 9 Feb 2026 22:15:32 -0800 Subject: [PATCH 3/9] [Bugfix][LoRA] enable a new testcase Signed-off-by: paulyu12 <507435917@qq.com> --- .github/workflows/scripts/config.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/scripts/config.yaml b/.github/workflows/scripts/config.yaml index a7f0cf490c5..4d6c7bb7658 100644 --- a/.github/workflows/scripts/config.yaml +++ b/.github/workflows/scripts/config.yaml @@ -95,6 +95,8 @@ e2e-multicard-2-cards: estimated_time: 400 - name: tests/e2e/multicard/2-cards/test_ilama_lora_tp2.py estimated_time: 60 + - name: tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py + estimated_time: 223 # Run the test in a separate step to avoid oom - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek_multistream_moe_tp2 estimated_time: 100 From b009b028e214c4fa6dc7e62ccfff40dccbe430ed Mon Sep 17 00:00:00 2001 From: paulyu12 <507435917@qq.com> Date: Wed, 25 Feb 2026 23:02:01 -0800 Subject: [PATCH 4/9] [Bugfix][LoRA] fix Signed-off-by: paulyu12 <507435917@qq.com> --- tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py b/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py index 614e80204c3..a8dc62b0e5b 100755 --- a/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py +++ b/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py @@ -14,7 +14,7 @@ enable_custom_op() # For hk region, we need to use the model from hf to avoid the network issue -MODEL_PATH = "meta-llama/Llama-3.2-3B-Instruct" +MODEL_PATH = "vllm-ascend/Llama-3.2-3B-Instruct" @patch.dict("os.environ", {"VLLM_USE_MODELSCOPE": "False"}) From cdaa760a857e1ad7d994dddefeefb29b4f01831a Mon Sep 17 00:00:00 2001 From: paulyu12 <507435917@qq.com> Date: Wed, 25 Feb 2026 23:03:22 -0800 Subject: [PATCH 5/9] [Bugfix][LoRA] fix Signed-off-by: paulyu12 <507435917@qq.com> --- tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py b/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py index a8dc62b0e5b..e6cdfde3ece 100755 --- a/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py +++ b/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py @@ -17,7 +17,6 @@ MODEL_PATH = "vllm-ascend/Llama-3.2-3B-Instruct" -@patch.dict("os.environ", {"VLLM_USE_MODELSCOPE": "False"}) @pytest.mark.parametrize("fully_sharded_loras", [False, True]) @wait_until_npu_memory_free() def test_llama_lora_tp2(llama32_lora_files, fully_sharded_loras): From ab7ca59f470729b392fc20a7827e276c029cce08 Mon Sep 17 00:00:00 2001 From: paulyu12 <507435917@qq.com> Date: Fri, 27 Feb 2026 06:35:09 -0800 Subject: [PATCH 6/9] [Bugfix][LoRA] fix Signed-off-by: paulyu12 <507435917@qq.com> --- vllm_ascend/worker/model_runner_v1.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 1f0ff0bfe10..8a416414b9a 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2177,6 +2177,8 @@ def _dummy_run( self.lora_config, num_scheduled_tokens, num_sampled_tokens, + remove_lora, + num_active_loras, ): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_padded <= self.max_num_tokens From 461677ae9223eded8958ab392a1b687ff62c24ab Mon Sep 17 00:00:00 2001 From: paulyu12 <507435917@qq.com> Date: Tue, 3 Mar 2026 16:25:32 +0800 Subject: [PATCH 7/9] [bugfix][LoRA] fix Signed-off-by: paulyu12 <507435917@qq.com> --- vllm_ascend/worker/model_runner_v1.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 8a416414b9a..1f0ff0bfe10 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2177,8 +2177,6 @@ def _dummy_run( self.lora_config, num_scheduled_tokens, num_sampled_tokens, - remove_lora, - num_active_loras, ): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_padded <= self.max_num_tokens From 0dada74fc0cb540f3ef8cde0e56a01ed9317afa9 Mon Sep 17 00:00:00 2001 From: paulyu12 <507435917@qq.com> Date: Tue, 10 Mar 2026 10:52:36 +0800 Subject: [PATCH 8/9] [bugfix][LoRA] lint Signed-off-by: paulyu12 <507435917@qq.com> --- tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py b/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py index e6cdfde3ece..f3b9ebc651e 100755 --- a/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py +++ b/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py @@ -2,13 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -import vllm -import vllm.config -from vllm.lora.request import LoRARequest -from unittest.mock import patch - from tests.e2e.conftest import VllmRunner, wait_until_npu_memory_free -from tests.e2e.singlecard.test_llama32_lora import PROMPT_TEMPLATE, EXPECTED_LORA_OUTPUT, EXPECTED_BASE_MODEL_OUTPUT, do_sample, generate_and_test +from tests.e2e.singlecard.test_llama32_lora import ( + generate_and_test +) from vllm_ascend.utils import enable_custom_op enable_custom_op() From 055e26e2f5113c0d9eee1e3dab3dbfad075ff6f4 Mon Sep 17 00:00:00 2001 From: paulyu12 <507435917@qq.com> Date: Tue, 10 Mar 2026 11:09:11 +0800 Subject: [PATCH 9/9] [bugfix][LoRA] lint Signed-off-by: paulyu12 <507435917@qq.com> --- tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py b/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py index f3b9ebc651e..06bb5065362 100755 --- a/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py +++ b/tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py @@ -3,9 +3,7 @@ import pytest from tests.e2e.conftest import VllmRunner, wait_until_npu_memory_free -from tests.e2e.singlecard.test_llama32_lora import ( - generate_and_test -) +from tests.e2e.singlecard.test_llama32_lora import generate_and_test from vllm_ascend.utils import enable_custom_op enable_custom_op()