5 changes: 0 additions & 5 deletions tests/entrypoints/openai/test_lora_adapters.py
@@ -23,11 +23,6 @@
{"r": 1024},
"is greater than max_lora_rank",
),
(
"test_bias",
{"bias": "all"},
"Adapter bias cannot be used without bias_enabled",
),
("test_dora", {"use_dora": True}, "does not yet support DoRA"),
(
"test_modules_to_save",
5 changes: 0 additions & 5 deletions tests/lora/test_peft_helper.py
@@ -16,11 +16,6 @@
{"r": 1024},
"is greater than max_lora_rank",
),
(
"test_bias",
{"bias": "all"},
"Adapter bias cannot be used without bias_enabled",
),
("test_dora", {"use_dora": True}, "does not yet support DoRA"),
(
"test_modules_to_save",
15 changes: 2 additions & 13 deletions tests/lora/test_utils.py
@@ -21,7 +21,6 @@ class LoRANameParserTestConfig(NamedTuple):
name: str
module_name: str
is_lora_a: bool
is_bias: bool
weights_mapper: Optional[WeightsMapper] = None


@@ -37,44 +36,37 @@ def test_parse_fine_tuned_lora_name_valid():
"base_model.model.model.embed_tokens.lora_embedding_A",
"model.embed_tokens",
True,
False,
),
LoRANameParserTestConfig(
"base_model.model.model.embed_tokens.lora_embedding_B",
"model.embed_tokens",
False,
False,
),
LoRANameParserTestConfig(
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
"model.layers.9.mlp.down_proj",
True,
False,
),
LoRANameParserTestConfig(
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
"model.layers.9.mlp.down_proj",
False,
False,
),
LoRANameParserTestConfig(
"language_model.layers.9.mlp.down_proj.lora_A.weight",
"language_model.layers.9.mlp.down_proj",
True,
False,
),
LoRANameParserTestConfig(
"language_model.layers.9.mlp.down_proj.lora_B.weight",
"language_model.layers.9.mlp.down_proj",
False,
False,
),
# Test with WeightsMapper
LoRANameParserTestConfig(
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
"language_model.model.layers.9.mlp.down_proj",
True,
False,
weights_mapper=WeightsMapper(
orig_to_new_prefix={"model.": "language_model.model."}
),
@@ -83,7 +75,6 @@ def test_parse_fine_tuned_lora_name_valid():
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
"language_model.model.layers.9.mlp.down_proj",
False,
False,
weights_mapper=WeightsMapper(
orig_to_new_prefix={"model.": "language_model.model."}
),
@@ -92,7 +83,6 @@ def test_parse_fine_tuned_lora_name_valid():
"model.layers.9.mlp.down_proj.lora_A.weight",
"language_model.model.layers.9.mlp.down_proj",
True,
False,
weights_mapper=WeightsMapper(
orig_to_new_prefix={"model.": "language_model.model."}
),
@@ -101,14 +91,13 @@ def test_parse_fine_tuned_lora_name_valid():
"model.layers.9.mlp.down_proj.lora_B.weight",
"language_model.model.layers.9.mlp.down_proj",
False,
False,
weights_mapper=WeightsMapper(
orig_to_new_prefix={"model.": "language_model.model."}
),
),
]
for name, module_name, is_lora_a, is_bias, weights_mapper in fixture:
assert (module_name, is_lora_a, is_bias) == parse_fine_tuned_lora_name(
for name, module_name, is_lora_a, weights_mapper in fixture:
assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(
name, weights_mapper
)

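Note (added for clarity, not part of this PR): with the `is_bias` field dropped from the fixture, `parse_fine_tuned_lora_name` is expected to return a 2-tuple. A minimal sketch of the updated call pattern, mirroring one fixture entry above; the import path `vllm.lora.utils` is an assumption, since the test's import lines are not shown in this diff.

```python
# Minimal sketch; import path assumed, call pattern taken from the test loop above.
from vllm.lora.utils import parse_fine_tuned_lora_name

name = "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight"
# The parser now returns (module_name, is_lora_a); the old is_bias element is gone.
module_name, is_lora_a = parse_fine_tuned_lora_name(name, None)
assert module_name == "model.layers.9.mlp.down_proj"
assert is_lora_a
```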
8 changes: 1 addition & 7 deletions vllm/config/lora.py
@@ -70,12 +70,6 @@ class LoRAConfig:
per prompt. When run in offline mode, the lora IDs for n modalities
will be automatically assigned to 1-n with the names of the modalities
in alphabetic order."""
bias_enabled: bool = Field(
default=False,
deprecated="`bias_enabled` is deprecated and will be removed in v0.12.0.",
)
"""[DEPRECATED] Enable bias for LoRA adapters. This option will be
removed in v0.12.0."""

def compute_hash(self) -> str:
"""
@@ -96,7 +90,7 @@ def compute_hash(self) -> str:
factors.append(self.lora_dtype)
factors.append(self.lora_extra_vocab_size)
factors.append(self.lora_vocab_padding_size)
factors.append(self.bias_enabled)

hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str

3 changes: 0 additions & 3 deletions vllm/engine/arg_utils.py
@@ -439,7 +439,6 @@ class EngineArgs:
video_pruning_rate: float = MultiModalConfig.video_pruning_rate
# LoRA fields
enable_lora: bool = False
enable_lora_bias: bool = LoRAConfig.bias_enabled
max_loras: int = LoRAConfig.max_loras
max_lora_rank: int = LoRAConfig.max_lora_rank
default_mm_loras: Optional[dict[str, str]] = LoRAConfig.default_mm_loras
@@ -916,7 +915,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
action=argparse.BooleanOptionalAction,
help="If True, enable handling of LoRA adapters.",
)
lora_group.add_argument("--enable-lora-bias", **lora_kwargs["bias_enabled"])
lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
lora_group.add_argument(
@@ -1515,7 +1513,6 @@ def create_engine_config(

lora_config = (
LoRAConfig(
bias_enabled=self.enable_lora_bias,
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
default_mm_loras=self.default_mm_loras,
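Note (added for clarity, not part of this PR): with `enable_lora_bias` removed from `EngineArgs` and the `--enable-lora-bias` flag gone from the CLI, LoRA is configured through the remaining fields only. A minimal sketch under the assumptions that `EngineArgs` stays a plain dataclass and that the model name below is a placeholder:

```python
from vllm.engine.arg_utils import EngineArgs

# Only fields still visible in this diff are used; passing the removed
# enable_lora_bias keyword would now be rejected by the dataclass __init__.
engine_args = EngineArgs(
    model="your-base-model",  # placeholder, not taken from this PR
    enable_lora=True,
    max_loras=1,
    max_lora_rank=16,
)
```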
1 change: 0 additions & 1 deletion vllm/lora/layers/base.py
@@ -45,7 +45,6 @@ def set_lora(
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
"""Overwrites lora tensors at index."""
...
40 changes: 2 additions & 38 deletions vllm/lora/layers/base_linear.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional, cast
from typing import Optional

import torch
from transformers import PretrainedConfig
@@ -29,7 +29,6 @@ def __init__(self, base_layer: LinearBase):
self.tp_size = self.base_layer.tp_size
self.tp_rank = self.base_layer.tp_rank
self.device = _get_lora_device(self.base_layer)
self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
self.output_slices: tuple[int, ...]
self.output_size: int
self.n_slices: int
@@ -86,38 +85,19 @@ def create_lora_weights(
)
for _ in range(self.n_slices)
)
if lora_config.bias_enabled:
lora_bias_out_size = lora_b_out_size
self.lora_bias_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_bias_out_size,
dtype=lora_config.lora_dtype,
device=self.device,
)
for _ in range(self.n_slices)
)
self.output_slices = (self.lora_b_stacked[0].shape[2],)

def reset_lora(self, index: int):
for s_index in range(self.n_slices):
self.lora_a_stacked[s_index][index] = 0
self.lora_b_stacked[s_index][index] = 0
if self.lora_config.bias_enabled:
# Make mypy happy
self.lora_bias_stacked = cast(
tuple[torch.Tensor, ...], self.lora_bias_stacked
)
self.lora_bias_stacked[s_index][index] = 0

def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
lora_bias: Optional[torch.Tensor] = None,
):
# Except for QKVParallelLinearWithLoRA and
# MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
@@ -131,23 +111,13 @@ def set_lora(
if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
if lora_bias is not None:
lora_bias = self.slice_bias(lora_bias)

self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
lora_a, non_blocking=True
)
self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
lora_b, non_blocking=True
)
if lora_bias is not None:
self.lora_bias_stacked = cast(
tuple[torch.Tensor, ...], self.lora_bias_stacked
)
assert len(self.lora_bias_stacked)
self.lora_bias_stacked[0][index, 0, : lora_bias.shape[0]].copy_(
lora_bias, non_blocking=True
)

def apply(
self, x: torch.Tensor, bias: Optional[torch.Tensor] = None
@@ -162,13 +132,7 @@ def apply(
x = x.flatten(0, 1)

lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_lora_linear(
output,
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.lora_bias_stacked,
1.0,
self.output_slices,
output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
)
if not current_platform.can_update_inplace():
output = lora_output
67 changes: 1 addition & 66 deletions vllm/lora/layers/column_parallel_linear.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional, Union, cast
from typing import Optional, Union

import torch
import torch.nn as nn
@@ -32,8 +32,6 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
== len(layer.lora_b_stacked)
== len(layer.output_slices)
)
if layer.lora_bias_stacked is not None:
assert layer.n_slices == len(layer.lora_bias_stacked)

output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

@@ -61,7 +59,6 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
output,
buffers,
layer.lora_b_stacked,
layer.lora_bias_stacked,
layer.output_slices,
offset_start=0,
add_input=True,
@@ -122,16 +119,6 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
lora_b = lora_b[start_idx:end_idx, :]
return lora_b

def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
# TODO: Fix the slicing logic of bias.
if bias is None:
return bias
shard_size = self.output_size
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
bias = bias[start_idx:end_idx]
return bias

def forward(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
@@ -238,17 +225,6 @@ def create_lora_weights(
)
for output_size in self.output_slices
)
if lora_config.bias_enabled:
self.lora_bias_stacked = tuple(
torch.zeros(
max_loras,
1,
output_size,
dtype=lora_config.lora_dtype,
device=self.device,
)
for output_size in self.output_slices
)

def slice_lora_a(
self, lora_a: list[Union[torch.Tensor, None]]
Expand All @@ -268,31 +244,18 @@ def slice_lora_b(
]
return sliced_lora_b

def slice_bias(
self, bias: list[Union[torch.Tensor, None]]
) -> list[Union[torch.Tensor, None]]:
for i, (shard_id, shard_size) in enumerate(
zip(self.output_ids, self.output_slices)
):
if (bias_i := bias[i]) is not None:
bias[i] = bias_i[shard_size * shard_id : shard_size * (shard_id + 1)]
return bias

def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
lora_bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)

if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
if lora_bias is not None:
lora_bias = self.slice_bias(lora_bias)

for i in range(self.n_slices):
if (lora_a_i := lora_a[i]) is not None:
@@ -304,16 +267,6 @@ def set_lora(
index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1]
].copy_(lora_b_i, non_blocking=True)

if lora_bias is not None:
self.lora_bias_stacked = cast(
tuple[torch.Tensor, ...], self.lora_bias_stacked
)
for i in range(self.n_slices):
if (lora_bias_i := lora_bias[i]) is not None:
self.lora_bias_stacked[i][index, 0, : lora_bias_i.shape[0]].copy_(
lora_bias_i, non_blocking=True
)

@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
@@ -380,24 +333,6 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
return lora_b

def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
bias_q = bias[
self.q_proj_shard_size * self.q_shard_id : self.q_proj_shard_size
* (self.q_shard_id + 1)
]
k_offset = self.q_proj_total_size
bias_k = bias[
k_offset + self.kv_proj_shard_size * self.kv_shard_id : k_offset
+ self.kv_proj_shard_size * (self.kv_shard_id + 1)
]
v_offset = k_offset + self.kv_proj_total_size
bias_v = bias[
v_offset + self.kv_proj_shard_size * self.kv_shard_id : v_offset
+ self.kv_proj_shard_size * (self.kv_shard_id + 1)
]
bias = torch.cat([bias_q, bias_k, bias_v], dim=1)
return bias

@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(