Skip to content
Merged
2 changes: 2 additions & 0 deletions .github/workflows/scripts/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ e2e-multicard-2-cards:
estimated_time: 400
- name: tests/e2e/multicard/2-cards/test_ilama_lora_tp2.py
estimated_time: 60
- name: tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py
estimated_time: 223
# Run the test in a separate step to avoid oom
- name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_deepseek_multistream_moe_tp2
estimated_time: 100
Expand Down
29 changes: 29 additions & 0 deletions tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from tests.e2e.conftest import VllmRunner, wait_until_npu_memory_free
from tests.e2e.singlecard.test_llama32_lora import generate_and_test
from vllm_ascend.utils import enable_custom_op

enable_custom_op()

# For hk region, we need to use the model from hf to avoid the network issue
MODEL_PATH = "vllm-ascend/Llama-3.2-3B-Instruct"


@pytest.mark.parametrize("fully_sharded_loras", [False, True])
@wait_until_npu_memory_free()
def test_llama_lora_tp2(llama32_lora_files, fully_sharded_loras):
with VllmRunner(
MODEL_PATH,
enable_lora=True,
# also test odd max_num_seqs
max_num_seqs=7,
max_model_len=1024,
max_loras=4,
tensor_parallel_size=2,
fully_sharded_loras=fully_sharded_loras,
) as vllm_model:
llm = vllm_model.model
generate_and_test(llm, llama32_lora_files)
11 changes: 3 additions & 8 deletions vllm_ascend/lora/punica_npu.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ def add_expand(
y: torch.Tensor,
x: tuple[torch.Tensor, ...] | torch.Tensor,
lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: tuple[torch.Tensor, ...] | None,
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
Expand All @@ -217,24 +216,20 @@ def add_expand(
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
lora_bias_stacked[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
offset += slice

Args:
y (torch.Tensor): Output tensor.
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
bias's weight
output_slices (Tuple[int, ...]): Every slice's size
offset_start (int): The starting position of y, defaults to 0
add_inputs (bool): Defaults to True.
"""
y_org = y
y = y.view(-1, y.shape[-1])
offset_left = offset_start
if lora_bias_stacked is not None:
self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked)
for slice_idx in range(len(lora_b_stacked)):
self._apply_expand(
y,
Expand Down Expand Up @@ -313,7 +308,7 @@ def add_lora_linear(
torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) for _ in range(len(output_slices))
)
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
self.add_expand(y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs)
self.add_expand(y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs)

def add_lora_logits(
self,
Expand Down
80 changes: 79 additions & 1 deletion vllm_ascend/lora/utils.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@
from vllm.config import LoRAConfig
from vllm.lora.layers import (
ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithLoRA,
MergedQKVParallelLinearWithShardedLoRA,
QKVParallelLinearWithLoRA,
QKVParallelLinearWithShardedLoRA,
RowParallelLinearWithLoRA,
RowParallelLinearWithShardedLoRA,
VocabParallelEmbeddingWithLoRA,
)
from vllm.lora.layers.utils import _not_fully_sharded_can_replace
from vllm.lora.layers.utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace

from vllm_ascend.ops.linear import (
AscendColumnParallelLinear,
Expand All @@ -23,6 +28,7 @@

class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
Expand All @@ -35,6 +41,7 @@ def can_replace_layer(

class AscendMergedColumnParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
Expand All @@ -47,6 +54,7 @@ def can_replace_layer(

class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
Expand Down Expand Up @@ -95,10 +103,80 @@ def can_replace_layer(
return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 3


class AscendColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithShardedLoRA):
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is AscendColumnParallelLinear


class AscendMergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithShardedLoRA):
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is AscendMergedColumnParallelLinear


class AscendMergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithShardedLoRA):
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 3


class AscendQKVParallelLinearWithShardedLoRA(QKVParallelLinearWithShardedLoRA):
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 1


class AscendRowParallelLinearWithShardedLoRA(RowParallelLinearWithShardedLoRA):
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is AscendRowParallelLinear


def refresh_all_lora_classes():
vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendMergedColumnParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithShardedLoRA)
vllm.lora.utils._all_lora_classes.add(AscendMergedColumnParallelLinearWithShardedLoRA)
vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithShardedLoRA)
vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithShardedLoRA)
vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithShardedLoRA)
Loading