[Misc][LoRA] Add --lora-target-modules to restrict LoRA to specific modules #34984

Merged: jeejeelee merged 10 commits into vllm-project:main from bhoomit:lora-target-modules on Mar 17, 2026
Conversation

bhoomit (Contributor) commented Feb 20, 2026

Purpose

Add deployment-time control over which model modules have LoRA applied via a new --lora-target-modules CLI parameter and LoRAConfig.target_modules field.

It accepts module suffixes (e.g., o_proj, qkv_proj) and restricts LoRA application to only those modules, which is useful for performance tuning. When not specified, all supported LoRA modules are used (existing behavior).

Usage

vllm serve model --enable-lora --lora-target-modules o_proj qkv_proj

Changes

  • vllm/config/lora.py: Add target_modules field to LoRAConfig
  • vllm/engine/arg_utils.py: Add --lora-target-modules CLI argument
  • vllm/lora/model_manager.py: Filter modules in _match_target_modules
  • docs/features/lora.md: Document the new parameter
  • Tests: CLI arg parsing and LoRAModelManager unit tests
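
Below is a minimal standalone sketch of the filtering behavior described above (illustrative only, not the actual vLLM code; the function name is made up):

    # A module keeps LoRA only if its name equals a configured target or ends
    # with ".<target>"; an empty/unset list means no restriction (existing behavior).
    def matches_target_modules(module_name: str, target_modules: list[str] | None) -> bool:
        if not target_modules:
            return True
        return any(
            module_name == target or module_name.endswith(f".{target}")
            for target in target_modules
        )

    assert matches_target_modules("model.layers.0.self_attn.o_proj", ["o_proj"])
    assert matches_target_modules("model.layers.0.self_attn.qkv_proj", None)
    assert not matches_target_modules("model.layers.0.mlp.down_proj", ["o_proj"])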

Benchmark: --lora-target-modules Latency Impact

Configuration

| Parameter     | Value                                       |
|---------------|---------------------------------------------|
| Model         | Qwen/Qwen3-32B (bf16)                       |
| GPU           | NVIDIA H200 (143 GB) × 1                    |
| LoRA rank     | 16                                          |
| vLLM version  | 0.16.0rc2.dev258                            |
| Torch version | 2.10.0+cu129                                |
| LoRA adapter  | Random weights (PEFT format, all 64 layers) |

Serving config: input_len=256, output_len=128, num_prompts=32, request_rate=2 req/s

Baseline = the adapter contains weights only for the subset's target modules; no --lora-target-modules flag.
With TM = the full adapter (all 4 module types) plus --lora-target-modules restricting at the engine level.
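
For reference, a random-weight PEFT adapter restricted to one module type (as in the Baseline runs) can be produced roughly like this. This is a hedged sketch using the standard peft API; the output path is a placeholder and this script is not part of the PR:

    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-32B")
    cfg = LoraConfig(r=16, target_modules=["o_proj"])  # rank 16, one module type
    get_peft_model(base, cfg).save_pretrained("qwen3-lora-o-proj")  # placeholder path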

Results with CUDA graphs + torch.compile (production mode)

TTFT (ms)

| Subset                 | Baseline | With TM | Δ      |
|------------------------|----------|---------|--------|
| all                    | 94.6     | 92.1    | −2.6%  |
| qkv_proj               | 94.2     | 74.3    | −21.1% |
| o_proj                 | 92.2     | 74.2    | −19.5% |
| gate_up_proj+down_proj | 95.8     | 86.6    | −9.6%  |

TPOT (ms)

| Subset                 | Baseline | With TM | Δ      |
|------------------------|----------|---------|--------|
| all                    | 23.4     | 23.3    | −0.1%  |
| qkv_proj               | 23.4     | 20.6    | −11.9% |
| o_proj                 | 23.4     | 20.2    | −13.6% |
| gate_up_proj+down_proj | 23.4     | 21.8    | −6.8%  |

Results with enforce_eager

TTFT (ms)

| Subset                 | Baseline | With TM | Δ      |
|------------------------|----------|---------|--------|
| all                    | 206.2    | 216.3   | +4.9%  |
| qkv_proj               | 214.3    | 123.1   | −42.6% |
| o_proj                 | 205.1    | 122.9   | −40.1% |
| gate_up_proj+down_proj | 216.9    | 154.8   | −28.6% |

TPOT (ms)

| Subset                 | Baseline | With TM | Δ      |
|------------------------|----------|---------|--------|
| all                    | 71.4     | 73.4    | +2.7%  |
| qkv_proj               | 72.3     | 38.1    | −47.3% |
| o_proj                 | 70.4     | 38.5    | −45.3% |
| gate_up_proj+down_proj | 74.8     | 50.1    | −33.0% |

Key takeaways

  • No overhead when all modules are active (differences are within noise)
  • CUDA graph mode: up to ~14% TPOT and ~21% TTFT reduction with a single-module restriction
  • Eager mode: up to ~47% TPOT reduction for single-module configs
  • Adapter-level restriction is ineffective: vLLM wraps all supported modules regardless of which modules the adapter contains, whereas --lora-target-modules skips wrapping entirely
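
For example, the qkv_proj "With TM" rows correspond to serving the full four-module adapter while restricting at the engine level. The adapter name and path below are placeholders; the flag shape follows the usage example above:

    vllm serve Qwen/Qwen3-32B --enable-lora \
        --lora-modules my-adapter=/path/to/full-adapter \
        --lora-target-modules qkv_proj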

mergify Bot commented Feb 20, 2026

Documentation preview: https://vllm--34984.org.readthedocs.build/en/34984/

@mergify mergify Bot added the documentation label Feb 20, 2026
@bhoomit bhoomit force-pushed the lora-target-modules branch from 1e56a85 to 3aa9721 Compare February 20, 2026 19:50
gemini-code-assist Bot left a comment

Code Review

This pull request introduces the --lora-target-modules CLI parameter and LoRAConfig.target_modules field, allowing users to restrict LoRA application to specific model modules at deployment time. This is a valuable feature for performance tuning. The implementation correctly integrates the new configuration into the engine and model manager. However, I have identified a critical logic bug in the vocab size validation for the logits processor and a performance/flexibility issue in the module matching logic that should be addressed.

I am having trouble creating individual review comments; my feedback is below.

vllm/lora/layers/logits_processor.py (91-94)

high

The validation logic here uses a comparison chain 32000 < self.base_layer.vocab_size > 258048 which is logically equivalent to self.base_layer.vocab_size > 258048 (since 258048 > 32000). Furthermore, the error message 32000 >= vocab_size <= 258048 is mathematically confusing and likely incorrect, as it implies vocab_size must be less than or equal to 32000. If the intent is to enforce an upper bound of 258048, the logic should be simplified and the message clarified.

        if self.base_layer.vocab_size > 258048:
            raise ValueError(
                f"When using LoRA, vocab size must be <= 258048, "
                f"but found {self.base_layer.vocab_size}"
            )
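
The root cause is Python's comparison chaining; a standalone illustration (not vLLM code):

    # `a < x > b` parses as `(a < x) and (x > b)`, so with a = 32000 and
    # b = 258048 the lower bound is redundant: any x > 258048 is already > 32000.
    x = 300_000
    assert (32000 < x > 258048) == ((32000 < x) and (x > 258048))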

vllm/lora/model_manager.py (571-572)

high

The current implementation of target_modules matching is too restrictive and inefficient.

  1. Restrictiveness: By only checking the last component of the module name (split(".")[-1]), users cannot target specific layers or sub-paths (e.g., layers.0.self_attn.o_proj). This is inconsistent with how supported_lora_modules are matched and how PEFT's target_modules usually work.
  2. Performance: Creating a set() from self.lora_config.target_modules inside this method is inefficient because _match_target_modules is called in a loop for every module in the model during initialization and warmup.

I suggest using a matching logic consistent with the is_supported check above, which also avoids the redundant set creation.

        return any(
            module_name.endswith(f".{target}") or module_name == target
            for target in self.lora_config.target_modules
        )

@bhoomit bhoomit force-pushed the lora-target-modules branch from 3aa9721 to 7a70487 Compare February 20, 2026 19:53
mergify Bot commented Feb 20, 2026

Hi @bhoomit, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

dcmaddix (Contributor) commented

Thanks @bhoomit. Looks good to me. Do we want similar logic for MoE-LoRA models?

bhoomit (Contributor, Author) commented Feb 20, 2026

> Thanks @bhoomit. Looks good to me. Do we want similar logic for MoE-LoRA models?

It should already work, as long as the TM identifier is the last part of the layer identifier, e.g. x.y.o_proj -> o_proj.
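
A quick illustration of that last-component match (standalone Python, not the vLLM source):

    # Only the final dotted component is compared against the target list:
    assert "model.layers.3.self_attn.o_proj".split(".")[-1] == "o_proj"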

@bhoomit bhoomit force-pushed the lora-target-modules branch from 7a70487 to 530f1e7 Compare February 20, 2026 21:11
mergify Bot commented Feb 20, 2026

Hi @bhoomit, the pre-commit checks have failed again. (Same instructions as above.)

@bhoomit bhoomit force-pushed the lora-target-modules branch from 530f1e7 to 340cee9 Compare February 20, 2026 21:17
dcmaddix (Contributor) commented

> > Thanks @bhoomit. Looks good to me. Do we want similar logic for MoE-LoRA models?
>
> It should already work, as long as the TM identifier is the last part of the layer identifier, e.g. x.y.o_proj -> o_proj.

Yes, the naming is the same, but MoE has other target parameters, e.g., gate_up_proj and down_proj.

bhoomit (Contributor, Author) commented Feb 20, 2026

> Yes, the naming is the same, but MoE has other target parameters, e.g., gate_up_proj and down_proj.

@dcmaddix Yes, it's tested for those two as well. It will work as expected.

cjackal (Contributor) commented Feb 21, 2026

Related PR and some previous discussion points are at #31452, FYI. As mentioned there, I think it would be ideal to accept layer indices, not only module types.

bhoomit (Contributor, Author) commented Feb 22, 2026

> Related PR and some previous discussion points are at #31452, FYI. As mentioned there, I think it would be ideal to accept layer indices, not only module types.

Thanks @cjackal for taking a look and adding reference to related PR.

I went through the discussion.

  1. While I agree that adding layer indices (and regex support) would add more flexibility, it would also make the DX/UX a bit more complicated for simple use cases. I believe --lora-target-modules should focus on the basic use case, which should cover a large portion of users. For regex and layer indices, we could consider a --lora-target-parameters or --lora-targets-with-pattern option (see the sketch after this list). WDYT?

  2. I see that this implementation is missing a warning when an adapter with "unsupported" targets is passed, as described in [LoRA] Add --lora-target-modules to selectively apply LoRA layers #31452 (comment). I will send an update.
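
Purely as a hypothetical illustration of item 1 (plain Python regex matching; not a vLLM API or flag), layer-index targeting could look like:

    import re

    # Hypothetical pattern: o_proj only in layers 0-1.
    pattern = re.compile(r"layers\.(0|1)\..*\.o_proj$")
    assert pattern.search("model.layers.0.self_attn.o_proj")
    assert not pattern.search("model.layers.5.self_attn.o_proj")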

cjackal (Contributor) commented Feb 22, 2026

> 1. While I agree that adding layer indices (and regex support) would add more flexibility, it would also make the DX/UX a bit more complicated for simple use cases. I believe --lora-target-modules should focus on the basic use case, which should cover a large portion of users. For regex and layer indices, we could consider a --lora-target-parameters or --lora-targets-with-pattern option. WDYT?

I also like the agile approach; we can cover the basic use case now and extend later. I mentioned that PR mostly because this feature is largely a matter of UX design (implementation-wise, all the building blocks for selective LoRA targets are already there, and the diff is pretty small, so it is not tough to do at any time) and @jeejeelee seems to have some opinions on the design.

mergify Bot commented Feb 23, 2026

Hi @bhoomit, the pre-commit checks have failed again. (Same instructions as above.)

@github-project-automation github-project-automation Bot moved this to In review in NVIDIA Mar 2, 2026
@mergify mergify Bot added the cpu label Mar 2, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Mar 2, 2026
@bhoomit bhoomit force-pushed the lora-target-modules branch from 8b06127 to 7e0bfcb Compare March 2, 2026 23:22
bhoomit and others added 6 commits March 2, 2026 15:23
…odules

Add deployment-time control over which model modules have LoRA applied
via a new --lora-target-modules CLI parameter and LoRAConfig.target_modules
field. This accepts module suffixes (e.g., o_proj, qkv_proj) and restricts
LoRA application to only those modules, useful for performance tuning.

When not specified, all supported LoRA modules are used (existing behavior).

Changes:
- vllm/config/lora.py: Add target_modules field to LoRAConfig
- vllm/engine/arg_utils.py: Add --lora-target-modules CLI argument
- vllm/lora/model_manager.py: Filter modules in _match_target_modules
- docs/features/lora.md: Document the new parameter
- tests: CLI arg parsing and LoRAModelManager unit tests

Signed-off-by: Bhoomit Vasani <bhoomit.2010@gmail.com>
…dules

Add a warning_once in _load_adapter when a LoRA adapter contains
modules not in the model's supported LoRA target modules. These
parameters would be silently ignored, which may cause unexpected
model behavior. The warning helps users identify misconfigured
adapters early.

Also adds a unit test that verifies the warning is emitted.

Signed-off-by: Bhoomit Vasani <bhoomit.2010@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Bhoomit Vasani <bhoomit.2010@gmail.com>
Signed-off-by: Bhoomit Vasani <bhoomit.2010@gmail.com>
…rning

Signed-off-by: Bhoomit Vasani <bhoomit.2010@gmail.com>
…educe test duplication

Signed-off-by: Bhoomit Vasani <bhoomit.2010@gmail.com>
@bhoomit bhoomit force-pushed the lora-target-modules branch from 7e0bfcb to c183eec Compare March 2, 2026 23:25
Comment thread on vllm/lora/worker_manager.py (Outdated)

    if not any(
        module_name.endswith(f".{suffix}")
        for suffix in supported_lora_modules
    ):
Contributor commented:

It looks like the if statement above should use the matching logic from _match_target_modules for consistency? Is that true, or am I missing something?

bhoomit (Contributor, Author) replied:

_match_target_modules checks more than the new lora-target-modules feature:

  1. Check whether the TM is in the model's supported_lora_modules
  2. Check whether the TM is in lora-target-modules (the new feature)

I see the inconsistency; working on making it better.

bhoomit (Contributor, Author) replied:

Addressed the concern in the latest commit.

…s and add target_modules warning

- Replace endswith() check with split('.')[-1] suffix matching in the
  unsupported module warning, consistent with _match_target_modules
- Add a second warning when an adapter module is excluded by the
  deployment-time target_modules restriction (previously silent)
- Add test_load_adapter_warns_on_target_modules_restriction to cover
  the new warning path
- Refactor _test_target_modules helper to accept expected_lora and
  expected_no_lora assertion lists, moving asserts into the helper

Signed-off-by: Bhoomit Vasani <bhoomit.2010@gmail.com>
Comment thread (diff context):

        module_name,
        lora_request.lora_path,
        ", ".join(sorted(target_modules)),
    )
Contributor commented:

Thanks for the refactor @bhoomit. Can we introduce a utility to do this check? Something like:

    # in a utils file
    def is_module_supported(module_name, supported_lora_modules, target_modules) -> bool:
        ...

    # model_manager.py
    def _match_target_modules(self, module_name: str) -> bool:
        return is_module_supported(
            module_name, self.supported_lora_modules, self.lora_config.target_modules
        )

    # worker_manager.py (here)
    if not is_module_supported(...):
        logger.warning_once("...")

This doesn't let us differentiate between what is unsupported and what is ignored, but I think that is fine. WDYT?

bhoomit (Contributor, Author) replied:

I can do that.

If we want to have two different warnings, we will need two utility functions, and they will be used by both of these files. Will update with that change.

Thanks

…s utils

Extract shared module-matching logic into two utility functions in
vllm/lora/utils.py so both model_manager.py and worker_manager.py
reuse the same checks:

- is_supported_lora_module: regex check against model-defined modules
- is_in_target_modules: suffix check against deployment-time filter

Add unit tests in tests/lora/test_lora_utils.py.

Signed-off-by: Bhoomit Vasani <bhoomit.2010@gmail.com>
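
A hedged sketch of what the two helpers described in this commit could look like (the names come from the commit message; signatures and bodies are assumptions, not copied from vllm/lora/utils.py):

    import re

    def is_supported_lora_module(module_name: str, supported_lora_modules: list[str]) -> bool:
        # The commit message describes a regex check against model-defined modules.
        return any(
            re.fullmatch(rf"(.*\.)?{pattern}", module_name) is not None
            for pattern in supported_lora_modules
        )

    def is_in_target_modules(module_name: str, target_modules: list[str]) -> bool:
        # The commit message describes a suffix check against the deployment-time filter.
        return any(
            module_name == target or module_name.endswith(f".{target}")
            for target in target_modules
        )

    name = "model.layers.0.self_attn.qkv_proj"
    assert is_supported_lora_module(name, ["qkv_proj", "o_proj"])
    assert not is_in_target_modules(name, ["o_proj"])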
varun-sundar-rabindranath (Contributor) left a comment

LGTM. There is another PR with wider filter support #31452 but there are some design questions that need answering there.

I am good with landing this PR for the immediate benefits if we think #31452 needs more thought.

Thanks @bhoomit

jeejeelee (Collaborator) commented:

@varun-sundar-rabindranath thank you, just added my stamp.
