Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docs/features/lora.md
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,17 @@ vllm serve model --enable-lora --max-lora-rank 64
# Bad: unnecessarily high, wastes memory
vllm serve model --enable-lora --max-lora-rank 256
```

### Restricting LoRA to Specific Modules

The `--lora-target-modules` parameter allows you to restrict which model modules have LoRA applied at deployment time. This is useful for performance tuning when you only need LoRA on specific layers:

```bash
# Apply LoRA only to output projection layers
vllm serve model --enable-lora --lora-target-modules o_proj

# Apply LoRA to multiple specific modules
vllm serve model --enable-lora --lora-target-modules o_proj qkv_proj down_proj
```

When `--lora-target-modules` is not specified, LoRA will be applied to all supported modules in the model. This parameter accepts module suffixes (the last component of the module name), such as `o_proj`, `qkv_proj`, `gate_proj`, etc.
29 changes: 29 additions & 0 deletions tests/entrypoints/openai/test_cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,32 @@ def test_served_model_name_parsing(tmp_path, vllm_parser, args, raises):
else:
with pytest.raises(raises):
vllm_parser.parse_args(args=args)


### Tests for LoRA target modules parsing
def test_lora_target_modules_single(serve_parser):
"""Test parsing single lora-target-modules argument"""
args = serve_parser.parse_args(
args=["--enable-lora", "--lora-target-modules", "o_proj"]
)
assert args.lora_target_modules == ["o_proj"]


def test_lora_target_modules_multiple(serve_parser):
"""Test parsing multiple lora-target-modules arguments"""
args = serve_parser.parse_args(
args=[
"--enable-lora",
"--lora-target-modules",
"o_proj",
"qkv_proj",
"down_proj",
]
)
assert args.lora_target_modules == ["o_proj", "qkv_proj", "down_proj"]


def test_lora_target_modules_default_none(serve_parser):
"""Test that lora-target-modules defaults to None"""
args = serve_parser.parse_args(args=[])
assert args.lora_target_modules is None
189 changes: 189 additions & 0 deletions tests/lora/test_lora_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,3 +711,192 @@ def test_packed_loras(default_vllm_config, dist_init, dummy_model_gate_up, devic
torch.testing.assert_close(
packed_lora1.lora_b[1], model_lora_clone1.get_lora("up_proj").lora_b
)


def _test_target_modules(
model,
target_modules: list[str] | None,
device: str,
expected_lora: list[tuple[str, type]],
expected_no_lora: list[tuple[str, type]],
):
"""Create a LoRAModelManager and assert which modules have LoRA applied."""
LoRAModelManager(
model,
2,
2,
2,
LoRAConfig(
max_lora_rank=8,
max_cpu_loras=2,
max_loras=2,
lora_dtype=DEFAULT_DTYPE,
target_modules=target_modules,
),
device=device,
)
for module_path, lora_cls in expected_lora:
assert isinstance(model.get_submodule(module_path), lora_cls)
for module_path, lora_cls in expected_no_lora:
assert not isinstance(model.get_submodule(module_path), lora_cls)


@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_config(default_vllm_config, dist_init, dummy_model, device):
"""Test that target_modules config restricts which modules get LoRA applied."""
_test_target_modules(
dummy_model,
["dense1"],
device,
expected_lora=[
("dense1", ColumnParallelLinearWithLoRA),
("layer1.dense1", ColumnParallelLinearWithLoRA),
],
expected_no_lora=[
("dense2", RowParallelLinearWithLoRA),
("layer1.dense2", RowParallelLinearWithLoRA),
],
)


@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_multiple(default_vllm_config, dist_init, dummy_model, device):
"""Test that multiple target_modules work correctly."""
_test_target_modules(
dummy_model,
["dense1", "dense2"],
device,
expected_lora=[
("dense1", ColumnParallelLinearWithLoRA),
("layer1.dense1", ColumnParallelLinearWithLoRA),
("dense2", RowParallelLinearWithLoRA),
("layer1.dense2", RowParallelLinearWithLoRA),
],
expected_no_lora=[],
)


@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_none_uses_all(
default_vllm_config, dist_init, dummy_model, device
):
"""Test that target_modules=None uses all supported modules."""
_test_target_modules(
dummy_model,
None,
device,
expected_lora=[
("dense1", ColumnParallelLinearWithLoRA),
("layer1.dense1", ColumnParallelLinearWithLoRA),
("dense2", RowParallelLinearWithLoRA),
("layer1.dense2", RowParallelLinearWithLoRA),
],
expected_no_lora=[],
)


@pytest.mark.parametrize("device", DEVICES)
def test_load_adapter_warns_on_unsupported_modules(
default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
"""Test that _load_adapter warns when a LoRA adapter contains modules
not in the model's supported LoRA target modules."""
from unittest.mock import patch

import vllm.lora.worker_manager as wm_module

lora_config = LoRAConfig(
max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
)

dummy_lora_files = f"{tmp_path}/lora_adapter"
os.makedirs(dummy_lora_files, exist_ok=True)
create_peft_lora(
dummy_model_gate_up,
save_dir=dummy_lora_files,
target_modules=["layer1.dense1", "dense2"],
lora_dtype=DEFAULT_DTYPE,
)

model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)
vllm_config.scheduler_config.max_num_seqs = 4
vllm_config.scheduler_config.max_num_batched_tokens = 2

worker_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
worker_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
worker_manager.create_lora_manager(dummy_model_gate_up)

# Patch from_local_checkpoint to inject an unsupported module
original_from_checkpoint = LoRAModel.from_local_checkpoint

def patched_from_checkpoint(*args, **kwargs):
lora = original_from_checkpoint(*args, **kwargs)
lora.loras["unsupported_module"] = LoRALayerWeights(
module_name="unsupported_module",
rank=8,
lora_alpha=16,
lora_a=torch.randn(8, 10),
lora_b=torch.randn(10, 8),
)
return lora

lora_request = LoRARequest("test", 1, dummy_lora_files)
with (
patch.object(LoRAModel, "from_local_checkpoint", patched_from_checkpoint),
patch.object(wm_module.logger, "warning_once") as mock_warning,
):
worker_manager._load_adapter(lora_request)
warning_args = mock_warning.call_args_list
found = any("unsupported_module" in str(call) for call in warning_args)
assert found, (
f"Expected warning about 'unsupported_module', got: {warning_args}"
)


@pytest.mark.parametrize("device", DEVICES)
def test_load_adapter_warns_on_target_modules_restriction(
default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
"""Test that _load_adapter warns when a LoRA adapter contains modules
excluded by the deployment-time target_modules restriction."""
from unittest.mock import patch

import vllm.lora.worker_manager as wm_module

# Restrict to only dense2 — adapter has dense1 which will be excluded
lora_config = LoRAConfig(
max_lora_rank=8,
max_cpu_loras=4,
max_loras=4,
lora_dtype=DEFAULT_DTYPE,
target_modules=["dense2"],
)

dummy_lora_files = f"{tmp_path}/lora_adapter"
os.makedirs(dummy_lora_files, exist_ok=True)
create_peft_lora(
dummy_model_gate_up,
save_dir=dummy_lora_files,
target_modules=["layer1.dense1", "dense2"],
lora_dtype=DEFAULT_DTYPE,
)

model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)
vllm_config.scheduler_config.max_num_seqs = 4
vllm_config.scheduler_config.max_num_batched_tokens = 2

worker_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
worker_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
worker_manager.create_lora_manager(dummy_model_gate_up)

lora_request = LoRARequest("test", 1, dummy_lora_files)
with patch.object(wm_module.logger, "warning_once") as mock_warning:
worker_manager._load_adapter(lora_request)
warning_args = mock_warning.call_args_list
# dense1 is supported by the model but excluded by target_modules
found = any("target_modules" in str(call) for call in warning_args)
assert found, (
f"Expected warning about target_modules restriction, got: {warning_args}"
)
8 changes: 8 additions & 0 deletions vllm/config/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class LoRAConfig:
`max_loras`."""
lora_dtype: torch.dtype | LoRADType = "auto"
"""Data type for LoRA. If auto, will default to base model dtype."""
target_modules: list[str] | None = None
"""Restrict LoRA to specific module suffixes (e.g., ["o_proj", "qkv_proj"]).
If None, all supported LoRA modules are used. This allows deployment-time
control over which modules have LoRA applied, useful for performance tuning."""
default_mm_loras: dict[str, str] | None = None
"""Dictionary mapping specific modalities to LoRA model paths; this field
is only applicable to multimodal models and should be leveraged when a
Expand Down Expand Up @@ -84,6 +88,10 @@ def compute_hash(self) -> str:
factors.append(self.fully_sharded_loras)
factors.append(self.lora_dtype)
factors.append(self.enable_tower_connector_lora)
# target_modules affects which modules get LoRA applied
factors.append(
tuple(sorted(self.target_modules)) if self.target_modules else None
)

hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
Expand Down
5 changes: 5 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ class EngineArgs:
fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
lora_target_modules: list[str] | None = LoRAConfig.target_modules
enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora
specialize_active_lora: bool = LoRAConfig.specialize_active_lora

Expand Down Expand Up @@ -1096,6 +1097,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
lora_group.add_argument(
"--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
)
lora_group.add_argument(
"--lora-target-modules", **lora_kwargs["target_modules"]
)
Comment thread
bhoomit marked this conversation as resolved.
lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
lora_group.add_argument(
"--specialize-active-lora", **lora_kwargs["specialize_active_lora"]
Expand Down Expand Up @@ -1773,6 +1777,7 @@ def create_engine_config(
default_mm_loras=self.default_mm_loras,
fully_sharded_loras=self.fully_sharded_loras,
lora_dtype=self.lora_dtype,
target_modules=self.lora_target_modules,
enable_tower_connector_lora=self.enable_tower_connector_lora,
specialize_active_lora=self.specialize_active_lora,
max_cpu_loras=self.max_cpu_loras
Expand Down
28 changes: 26 additions & 2 deletions vllm/lora/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,14 +538,38 @@ def create_dummy_lora(
model.loras[module_name] = lora
return model

def _match_target_modules(self, module_name: str):
return any(
def _match_target_modules(self, module_name: str) -> bool:
"""Check if a module should have LoRA applied.

This method first checks if the module is in vLLM's supported LoRA
modules, then applies deployment-time restrictions based on
LoRAConfig.target_modules.

Args:
module_name: Full dot-separated module name (e.g.,
"model.layers.0.self_attn.o_proj")

Returns:
True if LoRA should be applied to this module, False otherwise.
"""
# First check if module is in vLLM's supported LoRA modules
is_supported = any(
re.match(
r".*\.{target_module}$".format(target_module=target_module), module_name
)
or target_module == module_name
for target_module in self.supported_lora_modules
)
if not is_supported:
return False

# Apply deployment-time restrictions from config
if self.lora_config.target_modules is None:
return True

# Restrict to allowed suffixes (e.g. only o_proj)
module_suffix = module_name.split(".")[-1]
return module_suffix in set(self.lora_config.target_modules)

def _get_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
"""
Expand Down
26 changes: 26 additions & 0 deletions vllm/lora/worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,32 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
skip_prefixes=lora_skip_prefixes,
)

# Warn about adapter modules that will be ignored.
# Use the same suffix-matching logic as _match_target_modules:
# take the last segment of the dot-separated module name.
target_modules = self.lora_config.target_modules
for module_name in lora.loras:
module_suffix = module_name.split(".")[-1]
if module_suffix not in supported_lora_modules:
logger.warning_once(
"LoRA module '%s' in adapter '%s' is not in the "
"model's supported LoRA target modules [%s]. "
"These parameters will be ignored, which may "
"cause abnormal model behavior.",
module_name,
lora_request.lora_path,
", ".join(sorted(supported_lora_modules)),
)
elif target_modules is not None and module_suffix not in target_modules:
logger.warning_once(
"LoRA module '%s' in adapter '%s' is not in the "
"deployment-time target_modules restriction [%s]. "
"These parameters will be ignored.",
module_name,
lora_request.lora_path,
", ".join(sorted(target_modules)),
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the refactor @bhoomit . can we introduce a utility to do this check ? something like,

# in a utils file. 
def is_module_supported(module_name, supported_lora_modules, target_modules) -> bool:
     ...

# model_manager.py
def _match_target_modules(self, module_name: str) -> bool:
   return is_module_supported(module_name, self.supported_lora_modules, self.lora_config.target_modules)

# worker_manager.py (here)
if not is_module_supported():
    logger.warning_once("...") 

this doesn't let us differentiate between what is not-supported and what is ignored. but I think that is fine. wdyt ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can do that.

If we want to have two diff warnings, we will need two utility function. And they will be used by both these files. Will update with that change.

Thanks


except FileNotFoundError as e:
# FileNotFoundError should be raised if both
# - No adapter found to download from huggingface (or in
Expand Down