-
-
Notifications
You must be signed in to change notification settings - Fork 14.6k
[XPU][8/N] Fix kernel bugs in XPU LoRA and MOE LORA #34115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
89fb31a
lora fix
chaojun-zhang 3ce06af
lora fix
chaojun-zhang e1f4431
skip punica unit test for non xpu platform
chaojun-zhang dd2cc67
skip punica unit test for non xpu platform
chaojun-zhang 55a0370
Merge branch 'main' into lora_fix_1
jikunshang ca43024
Merge branch 'main' into lora_fix_1
jikunshang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,298 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from tests.lora.utils import ( | ||
| PunicaTensors, | ||
| assert_close, | ||
| generate_data, | ||
| generate_data_for_expand_nslices, | ||
| ) | ||
| from vllm.lora.ops.xpu_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink | ||
| from vllm.platforms import current_platform | ||
|
|
||
|
|
||
| def torch_bgmv_expand( | ||
| inputs: torch.Tensor, | ||
| lora_b_weights: torch.Tensor, | ||
| output_tensor: torch.Tensor, | ||
| lora_indices_tensor: torch.Tensor, | ||
| add_inputs: bool = True, | ||
| ): | ||
| selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) | ||
| if len(selected_loras.shape) == 4: | ||
| selected_loras = selected_loras.squeeze(dim=1) | ||
| inputs = inputs.to(dtype=output_tensor.dtype) | ||
| outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) | ||
|
|
||
| limit = output_tensor.shape[0] | ||
| if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: | ||
| limit = 1 | ||
|
|
||
| # LoRA adapter and model may add different amounts of padding to output | ||
| common_len = min(outputs.shape[1], output_tensor.shape[1]) | ||
|
|
||
| if add_inputs: | ||
| output_tensor[:, :common_len] += outputs[:limit, :common_len] | ||
| else: | ||
| output_tensor[:, :common_len] = outputs[:limit, :common_len] | ||
|
|
||
|
|
||
| def torch_bgmv_shrink( | ||
| inputs: torch.Tensor, | ||
| lora_b_weights: torch.Tensor, | ||
| output_tensor: torch.Tensor, | ||
| lora_indices_tensor: torch.Tensor, | ||
| scaling: float = 1.0, | ||
| ): | ||
| selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) | ||
| if len(selected_loras.shape) == 4: | ||
| selected_loras = selected_loras.squeeze(dim=1) | ||
| inputs = inputs.to(dtype=output_tensor.dtype) | ||
| outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) | ||
|
|
||
| output_tensor[:, : outputs.shape[1]] = scaling * outputs[:] | ||
|
|
||
|
|
||
| def torch_bgmv_expand_slice( | ||
| inputs: torch.Tensor, | ||
| lora_b_weights: torch.Tensor, | ||
| output_tensor: torch.Tensor, | ||
| lora_indices_tensor: torch.Tensor, | ||
| slice_offset: int, | ||
| slice_size: int, | ||
| add_inputs: bool = True, | ||
| ): | ||
| selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) | ||
| inputs = inputs.to(dtype=output_tensor.dtype) | ||
| if len(selected_loras.shape) == 4: | ||
| selected_loras = selected_loras.squeeze(dim=1) | ||
| outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) | ||
|
|
||
| if add_inputs: | ||
| output_tensor[:, slice_offset : slice_offset + slice_size] += outputs[:] | ||
| else: | ||
| output_tensor[:, slice_offset : slice_offset + slice_size] = outputs[:] | ||
|
|
||
|
|
||
| def check_bgmv_shrink( | ||
| batches: int, | ||
| num_loras: int, | ||
| rank: int, | ||
| hidden_size: int, | ||
| dtype: torch.dtype, | ||
| device: str, | ||
| scaling: float, | ||
| ): | ||
| """ | ||
| Compare vllm.bgmv_shrink against a reference implementation. | ||
| """ | ||
| seq_length = 1 | ||
| data: PunicaTensors = generate_data( | ||
| batches, | ||
| hidden_size, | ||
| num_loras, | ||
| rank, | ||
| seq_length, | ||
| dtype, | ||
| "shrink", | ||
| device, | ||
| ) | ||
|
|
||
| bgmv_shrink( | ||
| data.inputs_tensor, | ||
| data.lora_weights, | ||
| data.our_out_tensor, | ||
| data.token_lora_mapping, | ||
| scaling, | ||
| ) | ||
|
|
||
| torch_bgmv_shrink( | ||
| data.inputs_tensor, | ||
| data.lora_weights, | ||
| data.ref_out_tensor, | ||
| data.token_lora_mapping, | ||
| scaling, | ||
| ) | ||
|
|
||
| data.ref_out_tensor = data.ref_out_tensor.to(torch.float32) | ||
| assert_close(data.our_out_tensor, data.ref_out_tensor) | ||
|
|
||
|
|
||
| def check_bgmv_expand( | ||
| batches: int, | ||
| num_loras: int, | ||
| rank: int, | ||
| hidden_size: int, | ||
| dtype: torch.dtype, | ||
| device: str, | ||
| add_inputs: bool, | ||
| ): | ||
| """ | ||
| Compare vllm.bgmv_expand against a reference implementation. | ||
| """ | ||
| seq_length = 1 | ||
| data: PunicaTensors = generate_data( | ||
| batches, | ||
| hidden_size, | ||
| num_loras, | ||
| rank, | ||
| seq_length, | ||
| dtype, | ||
| "expand", | ||
| device, | ||
| ) | ||
|
|
||
| bgmv_expand( | ||
| data.inputs_tensor, | ||
| data.lora_weights, | ||
| data.our_out_tensor, | ||
| data.token_lora_mapping, | ||
| add_inputs=add_inputs, | ||
| ) | ||
| torch_bgmv_expand( | ||
| data.inputs_tensor, | ||
| data.lora_weights, | ||
| data.ref_out_tensor, | ||
| data.token_lora_mapping, | ||
| add_inputs=add_inputs, | ||
| ) | ||
| assert_close(data.ref_out_tensor, data.our_out_tensor) | ||
|
|
||
|
|
||
| def check_bgmv_expand_slice( | ||
| batches: int, | ||
| num_loras: int, | ||
| rank: int, | ||
| hidden_size: int, | ||
| nslices: int, | ||
| dtype: torch.dtype, | ||
| device: str, | ||
| add_inputs: bool, | ||
| ): | ||
| """ | ||
| Compare vllm.bgmv_expand_slice against a reference implementation. | ||
| """ | ||
| seq_length = 1 | ||
| data: PunicaTensors = generate_data_for_expand_nslices( | ||
| batches, | ||
| hidden_size, | ||
| num_loras, | ||
| rank, | ||
| seq_length, | ||
| dtype, | ||
| nslices, | ||
| device, | ||
| ) | ||
|
|
||
| slice_offset = 0 | ||
| for index in range(nslices): | ||
| bgmv_expand_slice( | ||
| data.inputs_tensor, | ||
| data.lora_weights[index], | ||
| data.our_out_tensor, | ||
| data.token_lora_mapping, | ||
| slice_offset, | ||
| slice_size=hidden_size, | ||
| add_inputs=add_inputs, | ||
| ) | ||
| torch_bgmv_expand_slice( | ||
| data.inputs_tensor, | ||
| data.lora_weights[index], | ||
| data.ref_out_tensor, | ||
| data.token_lora_mapping, | ||
| slice_offset, | ||
| slice_size=hidden_size, | ||
| add_inputs=add_inputs, | ||
| ) | ||
|
|
||
| slice_offset += hidden_size | ||
| assert_close(data.ref_out_tensor, data.our_out_tensor) | ||
|
|
||
|
|
||
| # General tests params that tests for variations in all dimensions | ||
| # except hidden_size. | ||
| test_params = { | ||
| "hidden_sizes": [2049], | ||
| "batches": [4], | ||
| "num_loras": [4], | ||
| "max_ranks": [32], | ||
| } | ||
|
|
||
| DTYPES = [torch.float16, torch.bfloat16] | ||
| DEVICES = [f"xpu:{0}"] | ||
| SEED = [0] | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("batches", test_params["batches"]) | ||
chaojun-zhang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| @pytest.mark.parametrize("num_loras", test_params["num_loras"]) | ||
| @pytest.mark.parametrize("rank", test_params["max_ranks"]) | ||
| @pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"]) | ||
| @pytest.mark.parametrize("dtype", DTYPES) | ||
| @pytest.mark.parametrize("device", DEVICES) | ||
| @pytest.mark.parametrize("seed", SEED) | ||
| @pytest.mark.parametrize("op_type", ["shrink", "expand"]) | ||
| @pytest.mark.skipif(not current_platform.is_xpu(), reason="skip for non xpu platform") | ||
| def test_bgmv( | ||
| batches: int, | ||
| num_loras: int, | ||
| rank: int, | ||
| hidden_size: int, | ||
| dtype: torch.dtype, | ||
| device: str, | ||
| seed: int, | ||
| op_type: str, | ||
| ): | ||
| if op_type == "shrink": | ||
| check_bgmv_shrink( | ||
| batches=batches, | ||
| num_loras=num_loras, | ||
| rank=rank, | ||
| hidden_size=hidden_size, | ||
| dtype=dtype, | ||
| device=device, | ||
| scaling=0.5, | ||
| ) | ||
| else: | ||
| check_bgmv_expand( | ||
| batches=batches, | ||
| num_loras=num_loras, | ||
| rank=rank, | ||
| hidden_size=hidden_size, | ||
| dtype=dtype, | ||
| device=device, | ||
| add_inputs=True, | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("batches", test_params["batches"]) | ||
| @pytest.mark.parametrize("num_loras", test_params["num_loras"]) | ||
| @pytest.mark.parametrize("rank", test_params["max_ranks"]) | ||
| @pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"]) | ||
| @pytest.mark.parametrize("nslices", [2, 3]) | ||
| @pytest.mark.parametrize("dtype", DTYPES) | ||
| @pytest.mark.parametrize("device", DEVICES) | ||
| @pytest.mark.parametrize("seed", SEED) | ||
| @pytest.mark.skipif(not current_platform.is_xpu(), reason="skip for non xpu platform") | ||
| def test_bgmv_expand_nslices( | ||
| batches: int, | ||
| num_loras: int, | ||
| rank: int, | ||
| hidden_size: int, | ||
| nslices: int, | ||
| dtype: torch.dtype, | ||
| device: str, | ||
| seed: int, | ||
| ): | ||
| check_bgmv_expand_slice( | ||
| batches=batches, | ||
| num_loras=num_loras, | ||
| rank=rank, | ||
| hidden_size=hidden_size, | ||
| nslices=nslices, | ||
| dtype=dtype, | ||
| device=device, | ||
| add_inputs=True, | ||
| ) | ||
2 changes: 1 addition & 1 deletion
2
vllm/lora/ops/ipex_ops/__init__.py → vllm/lora/ops/xpu_ops/__init__.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| from vllm.lora.ops.ipex_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink | ||
| from vllm.lora.ops.xpu_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink | ||
|
|
||
| __all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The parameter
lora_b_weightsin thetorch_bgmv_shrinkfunction is misleading. The shrink operation involves the LoRA A weights, not B. While the tensor shapes might align correctly for theeinsumoperation, using the wrong name makes the reference implementation confusing and hard to maintain. It should be renamed tolora_a_weightsfor clarity and correctness.