2 changes: 1 addition & 1 deletion tests/lora/test_layers.py
@@ -28,7 +28,7 @@
RowParallelLinearWithShardedLoRA,
VocabParallelEmbeddingWithLoRA,
)
from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
2 changes: 1 addition & 1 deletion tests/lora/test_lora_checkpoints.py
@@ -3,7 +3,7 @@

import pytest

from vllm.lora.models import LoRAModel
from vllm.lora.lora_model import LoRAModel
from vllm.lora.peft_helper import PEFTHelper
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
from vllm.model_executor.models.utils import WeightsMapper
2 changes: 1 addition & 1 deletion tests/lora/test_lora_huggingface.py
@@ -3,7 +3,7 @@

import pytest

from vllm.lora.models import LoRAModel
from vllm.lora.lora_model import LoRAModel
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.utils import get_adapter_absolute_path
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
4 changes: 2 additions & 2 deletions tests/lora/test_lora_manager.py
@@ -15,10 +15,10 @@
MergedColumnParallelLinearWithLoRA,
RowParallelLinearWithLoRA,
)
from vllm.lora.lora_model import LoRAModel
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (
from vllm.lora.model_manager import (
LoRAMapping,
LoRAModel,
LoRAModelManager,
LRUCacheLoRAModelManager,
)
2 changes: 1 addition & 1 deletion tests/lora/test_worker.py
@@ -16,7 +16,7 @@
)
from vllm.config.load import LoadConfig
from vllm.config.lora import LoRAConfig
from vllm.lora.models import LoRAMapping
from vllm.lora.model_manager import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.v1.worker.gpu_worker import Worker

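Taken together, the test updates above show where the public LoRA classes now live. A minimal sketch of the resulting import layout (the commented lines are the pre-refactor paths, per the deleted imports):

# before this change, these names were imported from vllm.lora.models, e.g.:
# from vllm.lora.models import LoRAModel, LoRALayerWeights, PackedLoRALayerWeights, LoRAMapping

# after this change, per the updated tests:
from vllm.lora.lora_model import LoRAModel
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.model_manager import (
    LoRAMapping,
    LoRAModelManager,
    LRUCacheLoRAModelManager,
)
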
246 changes: 246 additions & 0 deletions vllm/lora/lora_model.py
@@ -0,0 +1,246 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os

import safetensors.torch
import torch

from vllm.logger import init_logger
from vllm.lora.lora_weights import LoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.utils import (
get_lora_id,
is_base_embeddding_weights,
is_regex_target_modules,
parse_fine_tuned_lora_name,
)
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.utils import WeightsMapper
from vllm.utils.platform_utils import is_pin_memory_available

logger = init_logger(__name__)


class LoRAModel:
"""A LoRA fine-tuned model."""

def __init__(
self,
lora_model_id: int,
rank: int,
loras: dict[str, LoRALayerWeights],
) -> None:
"""
Args:
lora_model_id: The integer id for the lora model.
rank: lora rank.
loras: module name -> weights for lora-replaced layers.

"""
self.id = lora_model_id

assert lora_model_id > 0, (
f"a valid lora id should be greater than 0, got {self.id}"
)
self.rank = rank
self.loras: dict[str, LoRALayerWeights] = loras

def clone(self, lora_model_id: int) -> "LoRAModel":
"""Return a copy of the object with different ids.

Will share the underlying tensors."""
return self.__class__(
lora_model_id,
rank=self.rank,
loras=self.loras.copy(),
)

def get_lora(self, module_name: str) -> LoRALayerWeights | None:
"""Get LoRA for a given module by name"""
return self.loras.get(module_name, None)

def check_lora_name(self, lora_name: str) -> bool:
return lora_name in self.loras

@classmethod
def from_lora_tensors(
cls,
lora_model_id: int,
tensors: dict[str, torch.Tensor],
peft_helper: PEFTHelper,
device: str = "cuda",
dtype: torch.dtype | None = None,
model_vocab_size: int | None = None,
weights_mapper: WeightsMapper | None = None,
) -> "LoRAModel":
"""Create a LoRAModel from a dictionary of tensors."""
pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items():
if is_base_embeddding_weights(tensor_name):
continue
module_name, is_lora_a = parse_fine_tuned_lora_name(
tensor_name, weights_mapper
)
if module_name not in loras:
loras[module_name] = LoRALayerWeights.from_config(
module_name, peft_helper
)

if is_lora_a:
if (
"lora_embedding_A" in tensor_name
and model_vocab_size is not None
and model_vocab_size != tensor.shape[1]
):
raise RuntimeError(
f"The embedding LoRA size({tensor.shape[1]}) must be consistent"
f" with the base model's vocabulary size({model_vocab_size})."
)
loras[module_name].lora_a = tensor.to(device=device, dtype=dtype)
if pin_memory:
loras[module_name].lora_a = loras[module_name].lora_a.pin_memory()
else:
loras[module_name].lora_b = tensor.to(device=device, dtype=dtype)

if pin_memory:
loras[module_name].lora_b = loras[module_name].lora_b.pin_memory()

return cls(lora_model_id, peft_helper.r, loras)

@classmethod
def from_local_checkpoint(
cls,
lora_dir: str,
expected_lora_modules: set[str],
peft_helper: PEFTHelper,
*,
lora_model_id: int | None = None,
device: str = "cuda",
dtype: torch.dtype | None = None,
model_vocab_size: int | None = None,
weights_mapper: WeightsMapper | None = None,
tensorizer_config_dict: dict | None = None,
) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint.

Args:
lora_dir: The local path that has lora data.
expected_lora_modules: Name of modules that are expected to be
replaced by lora.
peft_helper: Loaded lora configuration information.
lora_model_id: LoRA model id. If not given, automatically set by
a global counter.
device: Device where the lora model is loaded.
dtype: dtype of the lora model weights.

Returns:
Loaded LoRA Model.
"""
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")

tensors: dict[str, torch.Tensor] = {}
unexpected_modules: list[list[str] | str] = []

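        # Reject checkpoints whose tensors target modules that are not in
        # expected_lora_modules; loading is aborted early with a clear error.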
def check_unexpected_modules(modules: dict):
for lora_module in modules.keys(): # noqa
if is_base_embeddding_weights(lora_module):
continue
# Handle PEFT file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj
if "base_layer" in lora_module:
continue
module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
# Case for expert lora weights
if ".experts" in module_name:
expert_idx = module_name.find(".experts")
expert_suffix = module_name[expert_idx + 1 :]
if expert_suffix not in expected_lora_modules:
unexpected_modules.append(module_name)

elif module_name.rsplit(".", 1)[-1] not in expected_lora_modules:
unexpected_modules.append(module_name)

if unexpected_modules:
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct"
)

if tensorizer_config_dict:
from tensorizer import TensorDeserializer

tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
lora_tensor_path = os.path.join(
tensorizer_config.tensorizer_dir, "adapter_model.tensors"
)
tensorizer_args = tensorizer_config._construct_tensorizer_args()
tensors = TensorDeserializer(
lora_tensor_path,
dtype=tensorizer_config.dtype,
**tensorizer_args.deserialization_kwargs,
)
check_unexpected_modules(tensors)

elif os.path.isfile(lora_tensor_path):
# Find unexpected modules.
# Use safetensor key as a source of truth to find expected modules.
# in peft if you have target_modules A, B, C and C does not exist
# in the model it won’t error and model will be trained with A, B
# loraified. C won’t exist in the safetensor but it will exist in
# the target_modules of the adapter_config.json.
unexpected_modules = []
with safetensors.safe_open(lora_tensor_path, framework="pt") as f: # type: ignore
# Load tensors if there are only expected modules.
check_unexpected_modules(f)
for module in f.keys(): # noqa
tensors[module] = f.get_tensor(module)
elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
# When a bin/pt file is provided, we rely on config to find
# unexpected modules.
unexpected_modules = []
target_modules = peft_helper.target_modules
if not isinstance(target_modules, list):
target_modules = [target_modules]
for module in target_modules:
# Compatible with more modules,
# such as:layers.11.self_attn.k_proj
part_name = module.split(".")[-1]
if part_name not in expected_lora_modules:
unexpected_modules.append(module)
# loaded lora's target modules must be a subset of
# expected_lora_modules. It is not reliable. See
# https://github.com/vllm-project/vllm/pull/5909. But there's no
# other better mechanism.
if unexpected_modules and not is_regex_target_modules(
peft_helper.target_modules, expected_lora_modules
):
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct"
)
lora_file_path = (
lora_bin_file_path
if os.path.isfile(lora_bin_file_path)
else lora_pt_file_path
)
tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
else:
raise ValueError(f"{lora_dir} doesn't contain tensors")

return cls.from_lora_tensors(
lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
tensors=tensors,
peft_helper=peft_helper,
device=device,
dtype=dtype,
model_vocab_size=model_vocab_size,
weights_mapper=weights_mapper,
)
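
For reference, a minimal usage sketch of the relocated class, assuming a local PEFT adapter directory; the PEFTHelper construction, the path, and the module names are illustrative and not part of this diff:

from vllm.lora.lora_model import LoRAModel
from vllm.lora.peft_helper import PEFTHelper

lora_dir = "/path/to/my-adapter"  # directory containing adapter_model.safetensors

# PEFTHelper.from_local_dir is assumed here for brevity; any way of building a
# PEFTHelper from the adapter's config serves the same purpose in this sketch.
peft_helper = PEFTHelper.from_local_dir(lora_dir, max_position_embeddings=4096)

lora = LoRAModel.from_local_checkpoint(
    lora_dir,
    expected_lora_modules={"q_proj", "k_proj", "v_proj", "o_proj"},  # illustrative
    peft_helper=peft_helper,
    device="cpu",  # tensors land on CPU; pinned memory is used when available
)
print(lora.id, lora.rank, sorted(lora.loras))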