14 changes: 14 additions & 0 deletions docs/features/lora.md
@@ -389,3 +389,17 @@ vllm serve model --enable-lora --max-lora-rank 64
# Bad: unnecessarily high, wastes memory
vllm serve model --enable-lora --max-lora-rank 256
```

### Restricting LoRA to Specific Modules

The `--lora-target-modules` parameter lets you restrict, at deployment time, which model modules have LoRA applied. This is useful for performance tuning when you only need LoRA on specific layers:

```bash
# Apply LoRA only to output projection layers
vllm serve model --enable-lora --lora-target-modules o_proj

# Apply LoRA to multiple specific modules
vllm serve model --enable-lora --lora-target-modules o_proj qkv_proj down_proj
```

When `--lora-target-modules` is not specified, LoRA will be applied to all supported modules in the model. This parameter accepts module suffixes (the last component of the module name), such as `o_proj`, `qkv_proj`, `gate_proj`, etc.
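
For clarity, here is a minimal Python sketch of the suffix-matching rule described above. The actual check lives in `vllm.lora.utils` (exercised by the new `tests/lora/test_lora_utils.py` in this PR as `is_in_target_modules`); the helper name below is an illustrative stand-in, not the real implementation:

```python
import re


def in_target_modules(module_name: str, target_modules: list[str] | None) -> bool:
    """Illustrative stand-in for the deployment-time filter.

    ``None`` means "no restriction"; an explicit list restricts LoRA to
    modules whose last component matches one of the given suffixes.
    """
    if target_modules is None:
        return True
    # Anchor at the end so "proj" does not match "o_proj"; a suffix matches
    # either the whole name or the component after the final dot.
    return any(
        re.search(rf"(^|\.){re.escape(suffix)}$", module_name)
        for suffix in target_modules
    )


assert in_target_modules("model.layers.0.self_attn.o_proj", ["o_proj"])
assert in_target_modules("dense1", ["dense1", "dense2"])
assert not in_target_modules("model.layers.0.self_attn.o_proj", ["proj"])
assert not in_target_modules("model.layers.0.self_attn.o_proj", [])
assert in_target_modules("model.layers.0.self_attn.o_proj", None)
```

This mirrors the behavior pinned down by the new `tests/lora/test_lora_utils.py` further down in this diff.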
29 changes: 29 additions & 0 deletions tests/entrypoints/openai/test_cli_args.py
@@ -291,3 +291,32 @@ def test_served_model_name_parsing(tmp_path, vllm_parser, args, raises):
else:
with pytest.raises(raises):
vllm_parser.parse_args(args=args)


### Tests for LoRA target modules parsing
def test_lora_target_modules_single(serve_parser):
"""Test parsing single lora-target-modules argument"""
args = serve_parser.parse_args(
args=["--enable-lora", "--lora-target-modules", "o_proj"]
)
assert args.lora_target_modules == ["o_proj"]


def test_lora_target_modules_multiple(serve_parser):
"""Test parsing multiple lora-target-modules arguments"""
args = serve_parser.parse_args(
args=[
"--enable-lora",
"--lora-target-modules",
"o_proj",
"qkv_proj",
"down_proj",
]
)
assert args.lora_target_modules == ["o_proj", "qkv_proj", "down_proj"]


def test_lora_target_modules_default_none(serve_parser):
"""Test that lora-target-modules defaults to None"""
args = serve_parser.parse_args(args=[])
assert args.lora_target_modules is None
189 changes: 189 additions & 0 deletions tests/lora/test_lora_manager.py
@@ -711,3 +711,192 @@ def test_packed_loras(default_vllm_config, dist_init, dummy_model_gate_up, devic
torch.testing.assert_close(
packed_lora1.lora_b[1], model_lora_clone1.get_lora("up_proj").lora_b
)


def _test_target_modules(
model,
target_modules: list[str] | None,
device: str,
expected_lora: list[tuple[str, type]],
expected_no_lora: list[tuple[str, type]],
):
"""Create a LoRAModelManager and assert which modules have LoRA applied."""
LoRAModelManager(
model,
2,
2,
2,
LoRAConfig(
max_lora_rank=8,
max_cpu_loras=2,
max_loras=2,
lora_dtype=DEFAULT_DTYPE,
target_modules=target_modules,
),
device=device,
)
for module_path, lora_cls in expected_lora:
assert isinstance(model.get_submodule(module_path), lora_cls)
for module_path, lora_cls in expected_no_lora:
assert not isinstance(model.get_submodule(module_path), lora_cls)


@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_config(default_vllm_config, dist_init, dummy_model, device):
"""Test that target_modules config restricts which modules get LoRA applied."""
_test_target_modules(
dummy_model,
["dense1"],
device,
expected_lora=[
("dense1", ColumnParallelLinearWithLoRA),
("layer1.dense1", ColumnParallelLinearWithLoRA),
],
expected_no_lora=[
("dense2", RowParallelLinearWithLoRA),
("layer1.dense2", RowParallelLinearWithLoRA),
],
)


@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_multiple(default_vllm_config, dist_init, dummy_model, device):
"""Test that multiple target_modules work correctly."""
_test_target_modules(
dummy_model,
["dense1", "dense2"],
device,
expected_lora=[
("dense1", ColumnParallelLinearWithLoRA),
("layer1.dense1", ColumnParallelLinearWithLoRA),
("dense2", RowParallelLinearWithLoRA),
("layer1.dense2", RowParallelLinearWithLoRA),
],
expected_no_lora=[],
)


@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_none_uses_all(
default_vllm_config, dist_init, dummy_model, device
):
"""Test that target_modules=None uses all supported modules."""
_test_target_modules(
dummy_model,
None,
device,
expected_lora=[
("dense1", ColumnParallelLinearWithLoRA),
("layer1.dense1", ColumnParallelLinearWithLoRA),
("dense2", RowParallelLinearWithLoRA),
("layer1.dense2", RowParallelLinearWithLoRA),
],
expected_no_lora=[],
)


@pytest.mark.parametrize("device", DEVICES)
def test_load_adapter_warns_on_unsupported_modules(
default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
"""Test that _load_adapter warns when a LoRA adapter contains modules
not in the model's supported LoRA target modules."""
from unittest.mock import patch

import vllm.lora.worker_manager as wm_module

lora_config = LoRAConfig(
max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
)

dummy_lora_files = f"{tmp_path}/lora_adapter"
os.makedirs(dummy_lora_files, exist_ok=True)
create_peft_lora(
dummy_model_gate_up,
save_dir=dummy_lora_files,
target_modules=["layer1.dense1", "dense2"],
lora_dtype=DEFAULT_DTYPE,
)

model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)
vllm_config.scheduler_config.max_num_seqs = 4
vllm_config.scheduler_config.max_num_batched_tokens = 2

worker_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
worker_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
worker_manager.create_lora_manager(dummy_model_gate_up)

# Patch from_local_checkpoint to inject an unsupported module
original_from_checkpoint = LoRAModel.from_local_checkpoint

def patched_from_checkpoint(*args, **kwargs):
lora = original_from_checkpoint(*args, **kwargs)
lora.loras["unsupported_module"] = LoRALayerWeights(
module_name="unsupported_module",
rank=8,
lora_alpha=16,
lora_a=torch.randn(8, 10),
lora_b=torch.randn(10, 8),
)
return lora

lora_request = LoRARequest("test", 1, dummy_lora_files)
with (
patch.object(LoRAModel, "from_local_checkpoint", patched_from_checkpoint),
patch.object(wm_module.logger, "warning_once") as mock_warning,
):
worker_manager._load_adapter(lora_request)
warning_args = mock_warning.call_args_list
found = any("unsupported_module" in str(call) for call in warning_args)
assert found, (
f"Expected warning about 'unsupported_module', got: {warning_args}"
)


@pytest.mark.parametrize("device", DEVICES)
def test_load_adapter_warns_on_target_modules_restriction(
default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
"""Test that _load_adapter warns when a LoRA adapter contains modules
excluded by the deployment-time target_modules restriction."""
from unittest.mock import patch

import vllm.lora.worker_manager as wm_module

# Restrict to only dense2 — adapter has dense1 which will be excluded
lora_config = LoRAConfig(
max_lora_rank=8,
max_cpu_loras=4,
max_loras=4,
lora_dtype=DEFAULT_DTYPE,
target_modules=["dense2"],
)

dummy_lora_files = f"{tmp_path}/lora_adapter"
os.makedirs(dummy_lora_files, exist_ok=True)
create_peft_lora(
dummy_model_gate_up,
save_dir=dummy_lora_files,
target_modules=["layer1.dense1", "dense2"],
lora_dtype=DEFAULT_DTYPE,
)

model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)
vllm_config.scheduler_config.max_num_seqs = 4
vllm_config.scheduler_config.max_num_batched_tokens = 2

worker_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
worker_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
worker_manager.create_lora_manager(dummy_model_gate_up)

lora_request = LoRARequest("test", 1, dummy_lora_files)
with patch.object(wm_module.logger, "warning_once") as mock_warning:
worker_manager._load_adapter(lora_request)
warning_args = mock_warning.call_args_list
# dense1 is supported by the model but excluded by target_modules
found = any("target_modules" in str(call) for call in warning_args)
assert found, (
f"Expected warning about target_modules restriction, got: {warning_args}"
)
60 changes: 60 additions & 0 deletions tests/lora/test_lora_utils.py
@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from vllm.lora.utils import is_in_target_modules, is_supported_lora_module


class TestIsSupportedLoraModule:
"""Tests for is_supported_lora_module (model-definition check)."""

def test_suffix_match(self):
assert is_supported_lora_module(
"model.layers.0.self_attn.o_proj", ["o_proj", "q_proj"]
)

def test_no_match(self):
assert not is_supported_lora_module(
"model.layers.0.self_attn.o_proj", ["q_proj", "k_proj"]
)

def test_exact_match(self):
assert is_supported_lora_module("o_proj", ["o_proj"])

def test_regex_suffix_matching(self):
"""Regex anchors to end — partial suffix should not match."""
assert not is_supported_lora_module("model.layers.0.self_attn.o_proj", ["proj"])

def test_empty_supported_modules(self):
assert not is_supported_lora_module("model.layers.0.self_attn.o_proj", [])

def test_multiple_supported_modules(self):
supported = ["q_proj", "k_proj", "v_proj", "o_proj"]
assert is_supported_lora_module("model.layers.0.self_attn.v_proj", supported)
assert not is_supported_lora_module("model.layers.0.mlp.gate_proj", supported)


class TestIsInTargetModules:
"""Tests for is_in_target_modules (deployment-time filter)."""

def test_none_allows_all(self):
assert is_in_target_modules("model.layers.0.self_attn.o_proj", None)

def test_suffix_in_target(self):
assert is_in_target_modules(
"model.layers.0.self_attn.o_proj", ["o_proj", "q_proj"]
)

def test_suffix_not_in_target(self):
assert not is_in_target_modules(
"model.layers.0.self_attn.o_proj", ["q_proj", "k_proj"]
)

def test_empty_target_modules(self):
assert not is_in_target_modules("model.layers.0.self_attn.o_proj", [])

def test_exact_name_match(self):
assert is_in_target_modules("dense1", ["dense1", "dense2"])

def test_exact_name_no_match(self):
assert not is_in_target_modules("dense3", ["dense1", "dense2"])
8 changes: 8 additions & 0 deletions vllm/config/lora.py
@@ -43,6 +43,10 @@ class LoRAConfig:
`max_loras`."""
lora_dtype: torch.dtype | LoRADType = "auto"
"""Data type for LoRA. If auto, will default to base model dtype."""
target_modules: list[str] | None = None
"""Restrict LoRA to specific module suffixes (e.g., ["o_proj", "qkv_proj"]).
If None, all supported LoRA modules are used. This allows deployment-time
control over which modules have LoRA applied, useful for performance tuning."""
default_mm_loras: dict[str, str] | None = None
"""Dictionary mapping specific modalities to LoRA model paths; this field
is only applicable to multimodal models and should be leveraged when a
@@ -84,6 +88,10 @@ def compute_hash(self) -> str:
factors.append(self.fully_sharded_loras)
factors.append(self.lora_dtype)
factors.append(self.enable_tower_connector_lora)
# target_modules affects which modules get LoRA applied
factors.append(
tuple(sorted(self.target_modules)) if self.target_modules else None
)

hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
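
A quick illustration of the hashing change above, as a hedged sketch: it assumes the remaining `LoRAConfig` fields can stay at their defaults when the config is constructed directly. Because the factor is `tuple(sorted(target_modules))`, the listed order of modules should not change the computed hash, while restricting modules should change it relative to the unrestricted default:

```python
from vllm.config.lora import LoRAConfig

# compute_hash() folds in tuple(sorted(target_modules)), so listing order
# should not change the cache key...
cfg_a = LoRAConfig(target_modules=["o_proj", "qkv_proj"])
cfg_b = LoRAConfig(target_modules=["qkv_proj", "o_proj"])
assert cfg_a.compute_hash() == cfg_b.compute_hash()

# ...while restricting modules yields a different key than no restriction.
assert cfg_a.compute_hash() != LoRAConfig().compute_hash()
```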
5 changes: 5 additions & 0 deletions vllm/engine/arg_utils.py
@@ -506,6 +506,7 @@ class EngineArgs:
fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
lora_target_modules: list[str] | None = LoRAConfig.target_modules
enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora
specialize_active_lora: bool = LoRAConfig.specialize_active_lora

@@ -1107,6 +1108,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
lora_group.add_argument(
"--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
)
lora_group.add_argument(
"--lora-target-modules", **lora_kwargs["target_modules"]
)
lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
lora_group.add_argument(
"--specialize-active-lora", **lora_kwargs["specialize_active_lora"]
@@ -1800,6 +1804,7 @@ def create_engine_config(
default_mm_loras=self.default_mm_loras,
fully_sharded_loras=self.fully_sharded_loras,
lora_dtype=self.lora_dtype,
target_modules=self.lora_target_modules,
enable_tower_connector_lora=self.enable_tower_connector_lora,
specialize_active_lora=self.specialize_active_lora,
max_cpu_loras=self.max_cpu_loras