Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tests/lora/test_fused_moe_lora_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
get_tensor_model_parallel_world_size,
)
from vllm.lora.ops.triton_ops import fused_moe_lora
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_port
from vllm.utils.torch_utils import set_random_seed

Expand Down Expand Up @@ -244,8 +245,9 @@ def use_torch(
return torch.stack(outputs, dim=0)


DEVICE_TYPE = current_platform.device_type
DTYPES = [torch.float16, torch.bfloat16]
DEVICES = [f"cuda:{0}"]
DEVICES = [f"{DEVICE_TYPE}:{0}"]
SEED = [42]


Expand Down
298 changes: 298 additions & 0 deletions tests/lora/test_punica_xpu_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,298 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from tests.lora.utils import (
PunicaTensors,
assert_close,
generate_data,
generate_data_for_expand_nslices,
)
from vllm.lora.ops.xpu_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
from vllm.platforms import current_platform


def torch_bgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
):
    """Reference bgmv_expand: per-token LoRA-B matmul in plain torch.

    For each token b, computes ``lora_b_weights[idx[b]] @ inputs[b]`` and
    writes (or accumulates, when ``add_inputs``) the result into
    ``output_tensor`` in place.
    """
    out_dtype = output_tensor.dtype
    weights = lora_b_weights[lora_indices_tensor].to(dtype=out_dtype)
    # Drop a singleton dim if the weights carry an extra axis of size 1.
    if weights.dim() == 4:
        weights = weights.squeeze(dim=1)
    projected = torch.einsum("bi, boi -> bo", inputs.to(dtype=out_dtype), weights)

    if projected.shape[0] == 1 and output_tensor.shape[0] != 1:
        row_limit = 1
    else:
        row_limit = output_tensor.shape[0]

    # LoRA adapter and model may add different amounts of padding to output
    width = min(projected.shape[1], output_tensor.shape[1])

    source = projected[:row_limit, :width]
    if add_inputs:
        output_tensor[:, :width] += source
    else:
        output_tensor[:, :width] = source


def torch_bgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
):
    """Reference bgmv_shrink: per-token scaled LoRA-A matmul in plain torch.

    The shrink step projects hidden states down to LoRA rank with the LoRA-A
    weights (the parameter was previously misnamed ``lora_b_weights``; all
    in-file callers pass it positionally, so the rename is safe). For each
    token b, writes ``scaling * (lora_a_weights[idx[b]] @ inputs[b])`` into
    ``output_tensor`` in place.
    """
    selected_loras = lora_a_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
    # Drop a singleton dim if the weights carry an extra axis of size 1.
    if len(selected_loras.shape) == 4:
        selected_loras = selected_loras.squeeze(dim=1)
    inputs = inputs.to(dtype=output_tensor.dtype)
    outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)

    output_tensor[:, : outputs.shape[1]] = scaling * outputs[:]


def torch_bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
):
    """Reference bgmv_expand_slice: expand one slice of the output in torch.

    Like :func:`torch_bgmv_expand`, but the per-token result lands in
    ``output_tensor[:, slice_offset:slice_offset + slice_size]`` (written or
    accumulated in place, depending on ``add_inputs``).
    """
    out_dtype = output_tensor.dtype
    weights = lora_b_weights[lora_indices_tensor].to(dtype=out_dtype)
    # Drop a singleton dim if the weights carry an extra axis of size 1.
    if weights.dim() == 4:
        weights = weights.squeeze(dim=1)
    projected = torch.einsum("bi, boi -> bo", inputs.to(dtype=out_dtype), weights)

    target = output_tensor[:, slice_offset : slice_offset + slice_size]
    if add_inputs:
        target += projected[:]
    else:
        target[:] = projected[:]


def check_bgmv_shrink(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    dtype: torch.dtype,
    device: str,
    scaling: float,
):
    """
    Compare vllm.bgmv_shrink against a reference implementation.
    """
    tensors: PunicaTensors = generate_data(
        batches,
        hidden_size,
        num_loras,
        rank,
        1,  # seq_length
        dtype,
        "shrink",
        device,
    )

    # Run the kernel under test and the torch reference on the same data.
    bgmv_shrink(
        tensors.inputs_tensor,
        tensors.lora_weights,
        tensors.our_out_tensor,
        tensors.token_lora_mapping,
        scaling,
    )
    torch_bgmv_shrink(
        tensors.inputs_tensor,
        tensors.lora_weights,
        tensors.ref_out_tensor,
        tensors.token_lora_mapping,
        scaling,
    )

    # NOTE(review): ref output is promoted to float32 before comparison,
    # presumably to match our_out_tensor's dtype — confirm vs generate_data.
    tensors.ref_out_tensor = tensors.ref_out_tensor.to(torch.float32)
    assert_close(tensors.our_out_tensor, tensors.ref_out_tensor)


def check_bgmv_expand(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    dtype: torch.dtype,
    device: str,
    add_inputs: bool,
):
    """
    Compare vllm.bgmv_expand against a reference implementation.
    """
    tensors: PunicaTensors = generate_data(
        batches,
        hidden_size,
        num_loras,
        rank,
        1,  # seq_length
        dtype,
        "expand",
        device,
    )

    # Run the kernel under test and the torch reference on the same data.
    bgmv_expand(
        tensors.inputs_tensor,
        tensors.lora_weights,
        tensors.our_out_tensor,
        tensors.token_lora_mapping,
        add_inputs=add_inputs,
    )
    torch_bgmv_expand(
        tensors.inputs_tensor,
        tensors.lora_weights,
        tensors.ref_out_tensor,
        tensors.token_lora_mapping,
        add_inputs=add_inputs,
    )
    assert_close(tensors.ref_out_tensor, tensors.our_out_tensor)


def check_bgmv_expand_slice(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    add_inputs: bool,
):
    """
    Compare vllm.bgmv_expand_slice against a reference implementation.
    """
    tensors: PunicaTensors = generate_data_for_expand_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        1,  # seq_length
        dtype,
        nslices,
        device,
    )

    # Each slice writes to its own hidden_size-wide window of the output.
    for index in range(nslices):
        offset = index * hidden_size
        bgmv_expand_slice(
            tensors.inputs_tensor,
            tensors.lora_weights[index],
            tensors.our_out_tensor,
            tensors.token_lora_mapping,
            offset,
            slice_size=hidden_size,
            add_inputs=add_inputs,
        )
        torch_bgmv_expand_slice(
            tensors.inputs_tensor,
            tensors.lora_weights[index],
            tensors.ref_out_tensor,
            tensors.token_lora_mapping,
            offset,
            slice_size=hidden_size,
            add_inputs=add_inputs,
        )
    assert_close(tensors.ref_out_tensor, tensors.our_out_tensor)


# General test params that cover variations in all dimensions
# except hidden_size.
test_params = {
    "hidden_sizes": [2049],  # deliberately non-power-of-two
    "batches": [4],
    "num_loras": [4],
    "max_ranks": [32],
}

# Parametrization axes shared by the tests below.
DTYPES = [torch.float16, torch.bfloat16]
DEVICES = [f"xpu:{0}"]
SEED = [0]


@pytest.mark.parametrize("batches", test_params["batches"])
@pytest.mark.parametrize("num_loras", test_params["num_loras"])
@pytest.mark.parametrize("rank", test_params["max_ranks"])
@pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.skipif(not current_platform.is_xpu(), reason="skip for non xpu platform")
def test_bgmv(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    op_type: str,
):
    """Check the XPU bgmv shrink/expand kernels against the torch reference."""
    common = dict(
        batches=batches,
        num_loras=num_loras,
        rank=rank,
        hidden_size=hidden_size,
        dtype=dtype,
        device=device,
    )
    if op_type == "shrink":
        check_bgmv_shrink(scaling=0.5, **common)
    else:
        check_bgmv_expand(add_inputs=True, **common)


@pytest.mark.parametrize("batches", test_params["batches"])
@pytest.mark.parametrize("num_loras", test_params["num_loras"])
@pytest.mark.parametrize("rank", test_params["max_ranks"])
@pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"])
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.skipif(not current_platform.is_xpu(), reason="skip for non xpu platform")
def test_bgmv_expand_nslices(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
):
    """Check the XPU bgmv_expand_slice kernel against the torch reference."""
    kwargs = dict(
        batches=batches,
        num_loras=num_loras,
        rank=rank,
        hidden_size=hidden_size,
        nslices=nslices,
        dtype=dtype,
        device=device,
        add_inputs=True,
    )
    check_bgmv_expand_slice(**kwargs)
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.lora.ops.ipex_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
from vllm.lora.ops.xpu_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink

__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@

logger = init_logger(__name__)

try:
import intel_extension_for_pytorch as ipex
except ImportError as e:
raise e


def bgmv_shrink(
inputs: torch.Tensor,
Expand All @@ -20,8 +15,8 @@ def bgmv_shrink(
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
) -> None:
ipex.llm.functional.bgmv_shrink(
inputs, lora_a_weights, output_tensor, lora_indices_tensor, scaling
torch.ops._xpu_C.bgmv_shrink(
output_tensor, inputs, lora_a_weights, lora_indices_tensor, scaling
)


Expand All @@ -32,8 +27,8 @@ def bgmv_expand(
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
) -> None:
ipex.llm.functional.bgmv_expand(
inputs, lora_b_weights, output_tensor, lora_indices_tensor, add_inputs
torch.ops._xpu_C.bgmv_expand(
output_tensor, inputs, lora_b_weights, lora_indices_tensor, add_inputs
)


Expand All @@ -46,10 +41,12 @@ def bgmv_expand_slice(
slice_size: int,
add_inputs: bool = True,
) -> None:
ipex.llm.functional.bgmv_expand_slice(
assert slice_size == lora_b_weights.size(-2)
assert slice_offset + slice_size <= output_tensor.size(1)
torch.ops._xpu_C.bgmv_expand_slice(
output_tensor,
inputs,
lora_b_weights,
output_tensor,
lora_indices_tensor,
slice_offset,
slice_size,
Expand Down
Loading