-
-
Notifications
You must be signed in to change notification settings - Fork 18.9k
[Hardware][TPU][V1] Multi-LoRA implementation for the V1 TPU backend #14238
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 114 commits
d993de9
4f816ed
5f0355b
edd02c5
aff94f9
adfd194
816a56c
c8a51c8
51f929d
23d4a24
47397a7
456eb37
eabc748
ac9753e
aa8b0fd
124215f
e148254
494b35e
1dbfcd9
1bb2578
ddc4cbc
48a6944
7802e84
ab5396b
fdf29d3
c2b4139
f31b7d1
87ff73e
96c3dde
4e72ede
e4d35ce
3cf0680
a8ab0c9
d73f1ce
e01d9a4
0c1bfb9
46ce7fa
5d0cc37
7590b0e
e7f75b5
fe193f7
52e3911
33a70b0
67446b2
0db19b1
a4c3b0a
9d6c388
2a9978e
b8c65bc
942ef07
56529b9
735073f
1067b50
d897f87
d6eca29
d97aae5
3ac0f63
a620e58
00d6dfd
fb0601d
89b062e
ef2ef8c
a79e19d
cc8cdf6
ad8c565
8d83065
dea7d02
0cf0eaa
6249307
af0a6a9
264d36a
b725c6a
038465c
2004369
71a1cdd
3dba9e0
f7f95e4
adfdcdb
a6d5c01
5a27785
5d15fbc
ca3d810
12f71ce
d040ee8
e696144
d110613
f8d5da2
d114377
430bae9
fb36fd6
c454062
23b14d1
27d6f70
b547271
1bb152f
af15bd1
41555d1
640420b
a02d0e9
e07d6fb
5b4ba1b
49a8102
c1be5f9
15ff074
ab036e0
b6af323
8ba2749
51d87a5
bf52dbd
8b1dae8
151fde4
8a3009d
9fb50b9
eb72ab6
c8f68d7
ed3b245
54c00c3
9f0fdbe
2012bbd
1803135
342ff8b
fc65edb
2f1da29
7daaafa
893ac04
2a0fce7
4d42844
50a06fc
ca68ce6
f4be6cc
155c2ad
317a131
b482ec8
2f26dd9
8ccbaa8
d227381
b65f60e
8a45758
987589a
bc49d0f
50e9738
4a07cf6
8cd5cb7
6282cd5
1846ef3
a006f6b
d72a86b
aff7414
e487ecb
20c5981
df67053
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| import pytest | ||
| import torch | ||
|
|
||
| # Required to register the custom ops | ||
| import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import | ||
|
|
||
| N_TOKENS = [ | ||
| 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, | ||
| 131072 | ||
| ] | ||
| HIDDEN_SIZES = [128, 256, 512, 896, 1024, 2048, 4096, 8192, 8320] | ||
|
|
||
| DTYPES = [torch.float16, torch.bfloat16] | ||
| NUM_LORA = [1, 2, 4, 8, 16, 32] | ||
| RANKS = [8, 16, 32, 64, 128] | ||
|
|
||
|
|
||
| def generate_test_data(T, D, L, N, seed, dtype=torch.float32): | ||
| """ | ||
| Inputs: (All integers) | ||
| T: Total number of tokens | ||
| D: Input dim | ||
| L: LoRA Dim | ||
| N: N LoRAs | ||
|
|
||
| Outputs: | ||
| inputs: torch.Tensor - shape (T, D) | ||
| loras: torch.Tensor - shape (N, 1, L, D) | ||
| idxs: torch.Tensor - shape (T, ) - all values must be in [0, N) | ||
|
|
||
| ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T | ||
| """ | ||
| torch.manual_seed(seed) | ||
|
|
||
| inputs = torch.randn((T, D), device="xla", dtype=dtype) | ||
| loras = torch.randn((N, 1, L, D), device="xla", dtype=dtype) | ||
| idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device="xla") | ||
|
|
||
| ref_output = ref_bgmv(inputs, loras, idxs) | ||
| return inputs, loras, idxs, ref_output | ||
|
|
||
|
|
||
| def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor): | ||
| selected_loras = loras[idxs] | ||
|
Akshat-Tripathi marked this conversation as resolved.
|
||
| if len(selected_loras.shape) == 4: | ||
| selected_loras = selected_loras.squeeze(axis=1) | ||
|
|
||
| batch_size, output_size, input_size = selected_loras.shape | ||
| return (selected_loras @ inputs.reshape( | ||
| (batch_size, input_size, 1))).reshape((batch_size, output_size)) | ||
|
|
||
|
|
||
| # Parameterize tests with various shapes and dtypes | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you need a test case where torch dynamo is involved?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think so, what would it add?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry for the late reply as I was OOO. If you need to torch.compile a region where it uses your kernel, then you need to test it. Otherwise, feel free to leave it as is
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No worries, I think that's covered by the E2E tests, since the model is torch.compiled |
||
| @pytest.mark.parametrize("T", N_TOKENS) | ||
| @pytest.mark.parametrize("D", HIDDEN_SIZES) | ||
| @pytest.mark.parametrize("L", RANKS) | ||
| @pytest.mark.parametrize("N", NUM_LORA) | ||
| @pytest.mark.parametrize("dtype", DTYPES) | ||
| @pytest.mark.parametrize("op_type", ["shrink", "expand"]) | ||
| @pytest.mark.parametrize("seed", [0]) | ||
| def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed): | ||
| if op_type == "expand": | ||
| D, L = L, D | ||
|
|
||
| inputs, loras, idxs, ref_output = generate_test_data( | ||
| T, D, L, N, seed, dtype) | ||
|
|
||
| # Run bgmv | ||
| output = torch.ops.xla.bgmv(inputs, loras, idxs) | ||
|
|
||
| # Make sure we have no NaNs | ||
| assert not torch.any(torch.isnan(output)) | ||
|
|
||
| # Compare with reference output | ||
| assert torch.allclose(output, ref_output, rtol=1e-3, atol=1e-3) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| import vllm | ||
| from vllm.lora.request import LoRARequest | ||
|
|
||
|
|
||
| def test_lora_hotswapping(): | ||
|
Akshat-Tripathi marked this conversation as resolved.
Outdated
Akshat-Tripathi marked this conversation as resolved.
Outdated
|
||
| lora_name_template = \ | ||
| "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. May I ask how can we run the test locally?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, you just need to run
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, it makes sense to split it into multiple small tests. |
||
| lora_requests = [ | ||
| LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i)) | ||
| for i in range(1, 5) | ||
| ] | ||
|
|
||
| llm = vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct", | ||
|
Akshat-Tripathi marked this conversation as resolved.
Outdated
|
||
| num_scheduler_steps=1, | ||
| max_model_len=256, | ||
| max_seq_len_to_capture=256, | ||
| max_num_seqs=8, | ||
| enable_lora=True, | ||
| max_loras=2, | ||
| max_lora_rank=8) | ||
|
|
||
| prompt = "What is 1+1? \n" | ||
|
|
||
| for _ in range(10): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder why we need to run 10 times
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree 10 is probably too high. I want to run twice to make sure we can do 1 full circuit of the LoRAs |
||
| for i, req in enumerate(lora_requests): | ||
| output = llm.generate(prompt, | ||
| sampling_params=vllm.SamplingParams( | ||
| max_tokens=256, temperature=0), | ||
| lora_request=req)[0].outputs[0].text | ||
| assert int(output.strip()[0]) == i + 1 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this just on my side?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This now works with the updated LoRA adapter |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -261,10 +261,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| full_lora_a_embeddings.shape[1], | ||
| -1, | ||
| ) | ||
| self.punica_wrapper.add_lora_embedding(full_output, | ||
| full_lora_a_embeddings, | ||
| self.lora_b_stacked, | ||
| add_input=True) | ||
|
|
||
| lora_output: Optional[ | ||
| torch.Tensor] = self.punica_wrapper.add_lora_embedding( | ||
| full_output, | ||
| full_lora_a_embeddings, | ||
| self.lora_b_stacked, | ||
| add_input=True) | ||
|
|
||
| if not current_platform.can_update_inplace(): | ||
|
Akshat-Tripathi marked this conversation as resolved.
|
||
| full_output = lora_output | ||
|
|
||
| return full_output.view_as(full_output_org) | ||
|
|
||
| @classmethod | ||
|
|
@@ -410,10 +417,13 @@ def apply(self, | |
| output = output.flatten(0, 1) | ||
| x = x.flatten(0, 1) | ||
|
|
||
| self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, | ||
| self.lora_b_stacked, | ||
| self.lora_bias_stacked, 1.0, | ||
| self.output_slices) | ||
| lora_output: Optional[ | ||
| torch.Tensor] = self.punica_wrapper.add_lora_linear( | ||
| output, x, self.lora_a_stacked, self.lora_b_stacked, | ||
| self.lora_bias_stacked, 1.0, self.output_slices) | ||
| if not current_platform.can_update_inplace(): | ||
| output = lora_output | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
|
|
||
| return output | ||
|
|
||
| @property | ||
|
|
@@ -1128,15 +1138,23 @@ def _get_logits( | |
| torch.matmul(self.embeddings_tensors, | ||
| hidden_states.T, | ||
| out=lora_logits[:-1]) | ||
| lora_logits[-1] = float("-inf") | ||
|
|
||
| neg_inf, pos_inf = current_platform.get_infinity_values( | ||
| lora_logits.dtype) | ||
|
|
||
| lora_logits[-1] = neg_inf | ||
| lora_logits = lora_logits.mT | ||
| indices_padded = self.punica_wrapper.sampler_indices_padded | ||
|
|
||
| if current_platform.is_tpu(): | ||
| indices_padded = indices_padded[:logits.size(0)] | ||
|
Akshat-Tripathi marked this conversation as resolved.
|
||
|
|
||
| lora_logits = (lora_logits.reshape( | ||
| lora_logits.shape[0] * lora_logits.shape[1], | ||
| lora_logits.shape[2], | ||
| ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"), | ||
| posinf=float("inf"), | ||
| neginf=float("-inf"))) | ||
| ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, | ||
| posinf=pos_inf, | ||
| neginf=neg_inf)) | ||
|
|
||
| # HPU needs special handling to prune out dummy samples. | ||
| if current_platform.is_hpu(): | ||
|
|
@@ -1146,10 +1164,13 @@ def _get_logits( | |
| self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + | ||
| lora_logits.shape[1]] = lora_logits | ||
|
|
||
| # LogitsProcessorWithLoRA always using bgmv | ||
| self.punica_wrapper.add_lora_logits(logits, hidden_states, | ||
| self.lora_a_stacked, | ||
| self.lora_b_stacked, 1.0) | ||
| lora_output: Optional[ | ||
| torch.Tensor] = self.punica_wrapper.add_lora_logits( | ||
| logits, hidden_states, self.lora_a_stacked, | ||
| self.lora_b_stacked, 1.0) | ||
|
|
||
| if not current_platform.can_update_inplace(): | ||
| logits = lora_output | ||
|
|
||
| # Remove paddings in vocab (if any). | ||
| logits = logits[:, :self.base_layer.vocab_size] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, | ||
| bgmv_shrink) | ||
|
|
||
| __all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import torch | ||
|
|
||
| # Required to register the custom ops | ||
| import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import | ||
|
|
||
|
|
||
| def bgmv_expand(inputs: torch.Tensor, | ||
|
Akshat-Tripathi marked this conversation as resolved.
|
||
| lora_b_weights: torch.Tensor, | ||
| output_tensor: torch.Tensor, | ||
| lora_indices_tensor: torch.Tensor, | ||
| add_inputs: bool = True): | ||
| outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) | ||
| n_tokens = outputs.size(0) | ||
|
|
||
| limit = output_tensor.shape[0] | ||
| if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: | ||
| limit = 1 | ||
|
|
||
| outputs = torch.cat( | ||
|
Akshat-Tripathi marked this conversation as resolved.
|
||
| (outputs, | ||
| torch.zeros((n_tokens, output_tensor.shape[1] - outputs.shape[1]), | ||
| device=outputs.device)), | ||
| dim=1) | ||
|
|
||
| if add_inputs: | ||
| return output_tensor + outputs[:limit, :] | ||
| else: | ||
| return outputs[:limit, :] | ||
|
|
||
|
|
||
| def bgmv_shrink(inputs: torch.Tensor, | ||
|
Akshat-Tripathi marked this conversation as resolved.
|
||
| lora_b_weights: torch.Tensor, | ||
| output_tensor: torch.Tensor, | ||
| lora_indices_tensor: torch.Tensor, | ||
| scaling: float = 1.0): | ||
|
|
||
| return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, | ||
| lora_indices_tensor) | ||
|
|
||
|
|
||
| def bgmv_expand_slice(inputs: torch.Tensor, | ||
| lora_b_weights: torch.Tensor, | ||
| output_tensor: torch.Tensor, | ||
| lora_indices_tensor: torch.Tensor, | ||
| slice_offset: int, | ||
| slice_size: int, | ||
| add_inputs: bool = True): | ||
| outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) | ||
| n_tokens = outputs.size(0) | ||
|
|
||
| outputs = torch.cat(( | ||
| torch.zeros((n_tokens, slice_offset), device=outputs.device), | ||
| outputs, | ||
| torch.zeros( | ||
| (n_tokens, output_tensor.shape[1] - (slice_offset + slice_size)), | ||
| device=outputs.device), | ||
| ), | ||
| dim=1) | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we use torch.ops.xla.dynamo_set_buffer_donor_ on output_tensor?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would that work when we do += too?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think so. |
||
| if add_inputs: | ||
| return output_tensor + outputs | ||
| else: | ||
| return outputs | ||
Uh oh!
There was an error while loading. Please reload this page.