Merged

25 commits
0f730c2  lora ops repair (ZT-AIA, Dec 26, 2025)
a71c033  open lora test (ZT-AIA, Dec 26, 2025)
3d5cf7a  add lora_config (ZT-AIA, Dec 27, 2025)
2eaf029  Merge branch 'main' into lora_opt (ZT-AIA, Dec 27, 2025)
5a6d1f1  change code format (ZT-AIA, Dec 27, 2025)
51d38d5  repair ci (ZT-AIA, Jan 3, 2026)
3b50c1d  Merge branch 'vllm-project:main' into lora_opt (ZT-AIA, Jan 4, 2026)
81e812b  complatible vllm v0.13.0 (ZT-AIA, Jan 5, 2026)
6791317  repair (ZT-AIA, Jan 5, 2026)
e0d9906  repair (ZT-AIA, Jan 5, 2026)
a8484c1  repair (ZT-AIA, Jan 6, 2026)
7956322  Merge branch 'vllm-project:main' into lora_opt (ZT-AIA, Jan 6, 2026)
84c4ae2  repair (ZT-AIA, Jan 7, 2026)
86cb6f3  Merge branch 'main' into lora_opt (ZT-AIA, Jan 7, 2026)
ebb5695  repair (ZT-AIA, Jan 7, 2026)
2bc6665  Merge branch 'lora_opt' of https://github.com/ZT-AIA/vllm-ascend into… (ZT-AIA, Jan 7, 2026)
a664ba0  Merge branch 'vllm-project:main' into lora_opt (ZT-AIA, Jan 7, 2026)
88f3c3a  repair (ZT-AIA, Jan 7, 2026)
8e81dc5  Merge branch 'lora_opt' of https://github.com/ZT-AIA/vllm-ascend into… (ZT-AIA, Jan 7, 2026)
deae86e  Merge branch 'vllm-project:main' into lora_opt (ZT-AIA, Jan 7, 2026)
5f07137  Merge branch 'vllm-project:main' into lora_opt (ZT-AIA, Jan 7, 2026)
50e0a03  Merge branch 'vllm-project:main' into lora_opt (ZT-AIA, Jan 8, 2026)
8f0e92a  repair (ZT-AIA, Jan 8, 2026)
cc668da  Merge branch 'lora_opt' of https://github.com/ZT-AIA/vllm-ascend into… (ZT-AIA, Jan 8, 2026)
e55e32f  Merge branch 'vllm-project:main' into lora_opt (ZT-AIA, Jan 9, 2026)
4 changes: 2 additions & 2 deletions .github/workflows/_e2e_test.yaml
@@ -105,7 +105,7 @@ jobs:
 # xgrammar has parameter mismatching bug, please follows: https://github.com/vllm-project/vllm-ascend/issues/5524
 # pytest -sv --durations=0 tests/e2e/singlecard/test_guided_decoding.py
 # torch 2.8 doesn't work with lora, fix me
-#pytest -sv --durations=0 tests/e2e/singlecard/test_ilama_lora.py
+pytest -sv --durations=0 tests/e2e/singlecard/test_ilama_lora.py
 pytest -sv --durations=0 tests/e2e/singlecard/test_models.py
 pytest -sv --durations=0 tests/e2e/singlecard/test_multistream_overlap_shared_expert.py
 pytest -sv --durations=0 tests/e2e/singlecard/test_profile_execute_duration.py
@@ -216,7 +216,7 @@ jobs:
 pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_external_launcher.py
 pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_full_graph_mode.py
 # torch 2.8 doesn't work with lora, fix me
-#pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_ilama_lora_tp2.py
+pytest -sv --durations=0 tests/e2e/multicard/2-cards/test_ilama_lora_tp2.py
 
 
 # To avoid oom, we need to run the test in a single process.
1 change: 1 addition & 0 deletions tests/e2e/multicard/2-cards/test_ilama_lora_tp2.py
@@ -18,6 +18,7 @@ def test_ilama_lora_tp2(distributed_executor_backend, ilama_lora_files):
             tensor_parallel_size=2,
             cudagraph_capture_sizes=[1, 2, 4, 8],
             distributed_executor_backend=distributed_executor_backend,
+            enforce_eager=True,
     ) as vllm_model:
         output = do_sample(vllm_model.model, ilama_lora_files, lora_id=2)
 
1 change: 1 addition & 0 deletions tests/e2e/singlecard/test_ilama_lora.py
@@ -53,6 +53,7 @@ def test_ilama_lora(ilama_lora_files):
             max_model_len=1024,
             cudagraph_capture_sizes=[1, 2, 4, 8],
             max_num_seqs=16,
+            enforce_eager=True,
     ) as vllm_model:
 
         output1 = do_sample(vllm_model.model, ilama_lora_files, lora_id=1)
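Both test diffs above make the same change: the LoRA e2e tests now pass enforce_eager=True so the model runs in eager mode. As a point of reference, here is a minimal standalone sketch of the equivalent user-facing configuration through vLLM's LLM entry point; the base model name and adapter path are placeholders, not values taken from this PR.

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder base model and adapter path, for illustration only.
llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # hypothetical base model
    enable_lora=True,
    max_loras=2,
    max_lora_rank=16,
    enforce_eager=True,  # mirrors the enforce_eager=True added in these tests
)

sampling = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(
    ["Briefly explain what a LoRA adapter is."],
    sampling,
    # LoRARequest(name, unique int id, path) is vLLM's standard LoRA handle.
    lora_request=LoRARequest("demo_adapter", 1, "/path/to/lora/adapter"),
)
print(outputs[0].outputs[0].text)
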
52 changes: 31 additions & 21 deletions vllm_ascend/lora/punica_npu.py
@@ -3,21 +3,10 @@
 from typing import Callable, Optional, Tuple, Union
 
 import torch
-
-from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
-
-if get_ascend_device_type() == AscendDeviceType._310P:
-    from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
-                                         bgmv_shrink, sgmv_expand,
-                                         sgmv_expand_slice, sgmv_shrink)
-else:
-    from vllm_ascend.lora.lora_ops import (bgmv_expand, bgmv_expand_slice,
-                                           bgmv_shrink, sgmv_expand,
-                                           sgmv_expand_slice, sgmv_shrink)
-
 from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
 
 from vllm_ascend.lora.utils import refresh_all_lora_classes
+from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
 
 
 # The platforms that are compatible with the PyTorch-native implementation can
@@ -34,6 +23,27 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
         PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
                                    device)
         refresh_all_lora_classes()
+        self.lora_config = kwargs.get("lora_config")
+        if get_ascend_device_type() == AscendDeviceType._310P or (
+                self.lora_config is not None
+                and self.lora_config.max_lora_rank >= 128):
+            from vllm.lora.ops.torch_ops import (bgmv_expand,
+                                                 bgmv_expand_slice,
+                                                 bgmv_shrink, sgmv_expand,
+                                                 sgmv_expand_slice,
+                                                 sgmv_shrink)
+        else:
+            from vllm_ascend.lora.lora_ops import (bgmv_expand,
+                                                   bgmv_expand_slice,
+                                                   bgmv_shrink, sgmv_expand,
+                                                   sgmv_expand_slice,
+                                                   sgmv_shrink)
+        self.bgmv_expand = bgmv_expand
+        self.bgmv_expand_slice = bgmv_expand_slice
+        self.bgmv_shrink = bgmv_shrink
+        self.sgmv_expand = sgmv_expand
+        self.sgmv_expand_slice = sgmv_expand_slice
+        self.sgmv_shrink = sgmv_shrink
 
     def _shrink_prefill(
         self,
@@ -45,7 +55,7 @@ def _shrink_prefill(
         #No LoRA request, so return directly
         if self.no_lora:
             return
-        sgmv_shrink(
+        self.sgmv_shrink(
             x,
             w_t_all,
             y,
@@ -60,7 +70,7 @@ def _shrink_decode(
         w_t_all: torch.Tensor,
         scale: float,
     ):
-        bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
+        self.bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
 
     def _expand_prefill(
         self,
@@ -72,7 +82,7 @@ def _expand_prefill(
         #No LoRA request, so return directly
         if self.no_lora:
             return
-        sgmv_expand(
+        self.sgmv_expand(
             x,
             w_t_all,
             y,
@@ -87,7 +97,7 @@ def _expand_decode(
         w_t_all: torch.Tensor,
         add_inputs: bool,
     ):
-        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
+        self.bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
 
     def _expand_slice_prefill(
         self,
@@ -101,7 +111,7 @@ def _expand_slice_prefill(
         #No LoRA request, so return directly
         if self.no_lora:
             return
-        sgmv_expand_slice(
+        self.sgmv_expand_slice(
             x,
             w_t_all,
             y,
@@ -120,8 +130,8 @@ def _expand_slice_decode(
         y_slice_size: int,
         add_inputs: bool,
     ):
-        bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
-                          y_slice_size, add_inputs)
+        self.bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices,
+                               y_offset, y_slice_size, add_inputs)
 
     def _apply_expand(
         self,
@@ -346,7 +356,7 @@ def add_lora_logits(self,

         indices = self.sampler_indices
 
-        bgmv_shrink(x, lora_a_stacked, buffer, indices, scale)
-        bgmv_expand(buffer, lora_b_stacked, y, indices, add_inputs=True)
+        self.bgmv_shrink(x, lora_a_stacked, buffer, indices, scale)
+        self.bgmv_expand(buffer, lora_b_stacked, y, indices, add_inputs=True)
 
         y = y.view_as(y_org)
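
The structural change running through all of the punica_npu.py hunks is visible in the __init__ hunk above: kernel selection moves from import time to construction time. On 310P hardware, or when the configured max_lora_rank is 128 or larger, the wrapper falls back to vLLM's PyTorch-native torch_ops; otherwise it uses the Ascend-optimized lora_ops. The chosen functions are bound as instance attributes, so every hot-path method calls self.sgmv_shrink and friends instead of a module-level name. Below is a minimal sketch of that dispatch pattern; the FakeLoRAConfig class and the stand-in op functions are illustrative placeholders, not the real kernels.

from typing import Callable, Optional


class FakeLoRAConfig:
    # Stand-in for vLLM's LoRAConfig; only the field the dispatch reads.
    def __init__(self, max_lora_rank: int):
        self.max_lora_rank = max_lora_rank


def torch_native_bgmv_shrink(x, w, y, indices, scale):
    # Placeholder for vllm.lora.ops.torch_ops.bgmv_shrink.
    print("torch-native shrink")


def ascend_bgmv_shrink(x, w, y, indices, scale):
    # Placeholder for vllm_ascend.lora.lora_ops.bgmv_shrink.
    print("ascend shrink")


class TinyPunicaWrapper:
    """Chooses a kernel set once at construction, then binds it on self."""

    def __init__(self, is_310p: bool, lora_config: Optional[FakeLoRAConfig]):
        use_torch_native = is_310p or (lora_config is not None
                                       and lora_config.max_lora_rank >= 128)
        self.bgmv_shrink: Callable = (torch_native_bgmv_shrink
                                      if use_torch_native else
                                      ascend_bgmv_shrink)

    def shrink_decode(self, x, w, y, indices, scale):
        # Hot path: no device/rank check here, just the bound callable.
        self.bgmv_shrink(x, w, y, indices, scale)


wrapper = TinyPunicaWrapper(is_310p=False, lora_config=FakeLoRAConfig(256))
wrapper.shrink_decode(None, None, None, None, 1.0)  # prints "torch-native shrink"

Binding the callables on the instance keeps the per-token code free of branching and keeps the Ascend kernel import lazy, at the cost of one extra attribute lookup per call.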