Skip to content

Refactor layer initialization to use PEFT config directly#2960

Merged
BenjaminBossan merged 7 commits into
huggingface:mainfrom
BenjaminBossan:refactor-adapter-layer-init
Jan 16, 2026
Merged

Refactor layer initialization to use PEFT config directly#2960
BenjaminBossan merged 7 commits into
huggingface:mainfrom
BenjaminBossan:refactor-adapter-layer-init

Conversation

@BenjaminBossan
Copy link
Copy Markdown
Member

Description

There is a lot of code duplication in the way that adapter layers are initialized. This means that if, say, a new LoRA config argument is added, we need updates in the following places:

  1. LoraModel._create_and_replace: kwargs dict
  2. LoraModel._create_and_replace: update_layer arguments
  3. LoraModel._create_and_replace: update_layer signature
  4. Linear.__init__: signature (same for each LoRA layer type)
  5. Linear.__init__: update_layer arguments
  6. resolve_lora_variant: signature if the argument is tied to a LoRA variant

It is far too easy to forget something and it's in fact a common mistake we see, especially by outside contributors.

This PR aims at removing this duplication as much as possible. The crux is that almost all of these arguments derive from the LoraConfig. Therefore, the duplication can be reduced by passing the LoraConfig directly instead of each of its individual arguments.

As is, this change is only implemented by LoRA, which is by far the most affected by this duplication. However, an argument could be made that for consistency, the other PEFT methods should be refactored accordingly. Since AdaLoRA inherits from LoRA, it was also adapted.

Possible counter points

With this change, the used arguments are not listed explicitly anymore. But since the LoraConfig is a typed dataclass, this replacement does not stop us from knowing which arguments are available (unlike, say, using untyped kwargs). So for editing support and readability, this replacement should not hurt.

Another possible issue with this proposal is that it could easily break some functionality, as it can be easy to overlook something. Not everything is covered by tests (some of the more rarely used quantized layer types, TP layer), so it's possible to miss this.

There is a lot of code duplication in the way that adapter layers are
initialized. This means that if, say, a new LoRA config argument is
added, we need updates in the following places:

1. LoraModel._create_and_replace: kwargs dict
2. LoraModel._create_and_replace: update_layer arguments
3. LoraModel._create_and_replace: update_layer signature
4. Linear.__init__: signature (same for each LoRA layer type)
5. Linear.__init__: update_layer arguments
6. resolve_lora_variant: signature if the argument is tied to a LoRA
   variant

It is far too easy to forget something and it's in fact a common mistake
we see, especially by outside contributors.

This PR aims at removing this duplication as much as possible. The crux
is that almost all of these arguments derive from the LoraConfig.
Therefore, the duplication can be reduced by passing the LoraConfig
directly instead of each of its individual arguments.

As is, this change is only implemented by LoRA, which is by far the most
affected by this duplication. However, an argument could be made that
for consistency, the other PEFT methods should be refactored
accordingly.

Possible counter points:

With this change, the used arguments are not listed explicitly anymore.
But since the LoraConfig is a typed dataclass, this replacement does not
stop us from knowing which arguments are available (unlike, say, using
untyped kwargs). So for editing support and readability, this
replacement should not hurt.

Another possible issue with this proposal is that it could easily break
some functionality, as it can be easy to overlook something. Not
everything is covered by tests (some of the more rarely used quantized
layer types, TP layer), so it's possible to miss this.
Comment on lines +153 to +157
lora_dropout = config.lora_dropout
init_lora_weights = config.init_lora_weights
use_rslora = config.use_rslora
lora_bias = config.lora_bias
inference_mode = config.inference_mode
Copy link
Copy Markdown
Member Author

@BenjaminBossan BenjaminBossan Dec 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I added variables like these to reduce diff noise. LMK if you'd rather remove those and just use config.lora_dropout etc. below.

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +147 to +148
r: int,
lora_alpha: int,
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: r and alpha can differ from the one in the config (e.g. because of rank_pattern), thus they are not taken from the config.

@BenjaminBossan BenjaminBossan marked this pull request as ready for review January 8, 2026 16:55
@BenjaminBossan BenjaminBossan changed the title [WIP] Refactor layer initialization Refactor layer initialization to use PEFT config directly Jan 12, 2026
Copy link
Copy Markdown
Collaborator

@githubnemo githubnemo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this looks fine but I noticed that lora/gptq.py seems unmodified still, is this on purpose?

@BenjaminBossan
Copy link
Copy Markdown
Member Author

@githubnemo Thanks for pointing this out, GPTQ is now also updated. I ran the GPTQ tests locally (as they are GPU tests) and they passed.

Copy link
Copy Markdown
Collaborator

@githubnemo githubnemo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks

@BenjaminBossan BenjaminBossan merged commit 91aad93 into huggingface:main Jan 16, 2026
9 of 10 checks passed
@BenjaminBossan BenjaminBossan deleted the refactor-adapter-layer-init branch January 16, 2026 17:03
@BenjaminBossan
Copy link
Copy Markdown
Member Author

@hiyouga Just a heads up, this PR simplifies the signature of multiple methods like update_layer. The idea is to directly pass the PEFT config instead of unpacking the arguments, which saves us a lot of code and makes PEFT easier to extend. For now, this is only for LoRA, but it will be extended to other PEFT methods later. In cause you're overriding these methods, please take a look at the new signature. From skimming LlamaFactory, it looks like you should be safe, but please check your tests using the PEFT main branch to be safe. If you encounter issues, please let us know.

@hiyouga
Copy link
Copy Markdown
Contributor

hiyouga commented Jan 17, 2026

@BenjaminBossan Thanks for the heads-up! All unit tests in LlamaFactory pass with the PEFT main branch, and we haven't encountered any issues.

githubnemo pushed a commit to githubnemo/peft that referenced this pull request Jan 22, 2026
The quantized adalora SVD layers were not initialized in the new
config-passing scheme and therefore raised errors in the GPU
tests.

For reproduction run

```
make tests_examples_single_gpu
```

which will yield

```
FAILED tests/test_gpu_examples.py::PeftBnbGPUExampleTests::test_4bit_adalora_causalLM - TypeError: SVDLinear4bit.__init__() missing 1 required positional argument: 'config'
FAILED tests/test_gpu_examples.py::PeftBnbGPUExampleTests::test_8bit_adalora_causalLM - TypeError: SVDLinear8bitLt.__init__() missing 1 required positional argument: 'config'
```
githubnemo pushed a commit to githubnemo/peft that referenced this pull request Jan 22, 2026
The quantized adalora SVD layers were not initialized in the new
config-passing scheme and therefore raised errors in the GPU
tests.

For reproduction run

```
make tests_examples_single_gpu
```

which will yield

```
FAILED tests/test_gpu_examples.py::PeftBnbGPUExampleTests::test_4bit_adalora_causalLM - TypeError: SVDLinear4bit.__init__() missing 1 required positional argument: 'config'
FAILED tests/test_gpu_examples.py::PeftBnbGPUExampleTests::test_8bit_adalora_causalLM - TypeError: SVDLinear8bitLt.__init__() missing 1 required positional argument: 'config'
```
githubnemo added a commit that referenced this pull request Jan 29, 2026
* Fix initialization bug introduced in #2960

The quantized adalora SVD layers were not initialized in the new
config-passing scheme and therefore raised errors in the GPU
tests.

For reproduction run

```
make tests_examples_single_gpu
```

which will yield

```
FAILED tests/test_gpu_examples.py::PeftBnbGPUExampleTests::test_4bit_adalora_causalLM - TypeError: SVDLinear4bit.__init__() missing 1 required positional argument: 'config'
FAILED tests/test_gpu_examples.py::PeftBnbGPUExampleTests::test_8bit_adalora_causalLM - TypeError: SVDLinear8bitLt.__init__() missing 1 required positional argument: 'config'
```

* Fix a small error with pytorch 2.10

Now it seems that a ValueError is raised indicating that the lora linear implementation
only supports 8 bit for now.

---------

Co-authored-by: nemo <git@ningu.net>
githubnemo pushed a commit to githubnemo/peft that referenced this pull request Feb 9, 2026
This addresses some of the errors reported by running the tests
on a single GPU machine.

I will list the error messages and a short explanation of the fix.

> `FAILED tests/test_common_gpu.py::PeftGPUCommonTests::test_lora_gptq_quantization_from_pretrained_safetensors - NameError: name 'BACKEND' is not defined`

The test was using GPTQModel without marking the test as requiring it leading to an error. This is fixed
by marking the test with `requires_gptqmodel`.

> `FAILED tests/test_custom_models.py::TestPeftCustomModel::test_only_params_are_updated[Embedding + transformers Conv1D 1 trainable_tokens-EmbConv1D-TrainableTokensConfig-config_kwargs180] - AssertionError: assert not True`
> `FAILED tests/test_custom_models.py::TestPeftCustomModel::test_disable_adapters_with_merging[Embedding + transformers Conv1D 1 trainable_tokens-EmbConv1D-TrainableTokensConfig-config_kwargs180] - AssertionError: assert not True`

This test fails because sometimes the gradients of the trainable tokens delta is 0 but only when training on CUDA,
CPU is fine.

This is a weird one and I'm not sure if this is a good fix or not. I encountered this error on two machines
(1xL40S and 4xA10G) and I was not able to pinpoint this to something particular in the environment, i.e.
PEFT version (tested v0.17 to main), transformers version (tested 4.5{5,6,7}, 5.0), CUDA version (tested 12.6, 12.8)
or torch version (tested 2.7, 2.8, 2.9, 2.10). I also set `LD_LIBRARY_PATH=` before running pytest to exclude
cuDNN libraries that come preinstalled on the EC2 instance.

Removing the ReLU in `EmbConv1DModel` as well as boosting the Conv1D weights will fix the error. Replacing
the ReLU with `Threshold(0, 0)` has the same behavior. It depends on the seed, i.e. if the initialization of
`Conv1D` is favorable the bug will not trigger.

I tried pinpointing it on `index_copy` but it is not `index_copy` by itself that is the problem. Maybe we will just
have to live with this?

> `FAILED tests/test_common_gpu.py::PeftGPUCommonTests::test_dora_ephemeral_gpu_offload_multigpu - RuntimeError: Expected all tensors to be on the same device, but got mat2 is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_mm)`

This is caused by a bug introduced in huggingface#2960 - `ephemeral_gpu_offload` is not passed to the variant and therefore
never utilized.
githubnemo pushed a commit to githubnemo/peft that referenced this pull request Feb 9, 2026
This addresses some of the errors reported by running the tests
on a single GPU machine.

I will list the error messages and a short explanation of the fix.

> `FAILED tests/test_common_gpu.py::PeftGPUCommonTests::test_lora_gptq_quantization_from_pretrained_safetensors - NameError: name 'BACKEND' is not defined`

The test was using GPTQModel without marking the test as requiring it leading to an error. This is fixed
by marking the test with `requires_gptqmodel`.

> `FAILED tests/test_custom_models.py::TestPeftCustomModel::test_only_params_are_updated[Embedding + transformers Conv1D 1 trainable_tokens-EmbConv1D-TrainableTokensConfig-config_kwargs180] - AssertionError: assert not True`
> `FAILED tests/test_custom_models.py::TestPeftCustomModel::test_disable_adapters_with_merging[Embedding + transformers Conv1D 1 trainable_tokens-EmbConv1D-TrainableTokensConfig-config_kwargs180] - AssertionError: assert not True`

This test fails because sometimes the gradients of the trainable tokens delta is 0 but only when training on CUDA,
CPU is fine.

This is a weird one and I'm not sure if this is a good fix or not. I encountered this error on two machines
(1xL40S and 4xA10G) and I was not able to pinpoint this to something particular in the environment, i.e.
PEFT version (tested v0.17 to main), transformers version (tested 4.5{5,6,7}, 5.0), CUDA version (tested 12.6, 12.8)
or torch version (tested 2.7, 2.8, 2.9, 2.10). I also set `LD_LIBRARY_PATH=` before running pytest to exclude
cuDNN libraries that come preinstalled on the EC2 instance.

Removing the ReLU in `EmbConv1DModel` as well as boosting the Conv1D weights will fix the error. Replacing
the ReLU with `Threshold(0, 0)` has the same behavior. It depends on the seed, i.e. if the initialization of
`Conv1D` is favorable the bug will not trigger.

I tried pinpointing it on `index_copy` but it is not `index_copy` by itself that is the problem. Maybe we will just
have to live with this?

> `FAILED tests/test_common_gpu.py::PeftGPUCommonTests::test_dora_ephemeral_gpu_offload_multigpu - RuntimeError: Expected all tensors to be on the same device, but got mat2 is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_mm)`

This is caused by a bug introduced in huggingface#2960 - `ephemeral_gpu_offload` is not passed to the variant and therefore
never utilized.
githubnemo pushed a commit to githubnemo/peft that referenced this pull request Feb 9, 2026
This addresses some of the errors reported by running the tests
on a single GPU machine.

I will list the error messages and a short explanation of the fix.

> `FAILED tests/test_common_gpu.py::PeftGPUCommonTests::test_lora_gptq_quantization_from_pretrained_safetensors - NameError: name 'BACKEND' is not defined`

The test was using GPTQModel without marking the test as requiring it leading to an error. This is fixed
by marking the test with `requires_gptqmodel`.

> `FAILED tests/test_custom_models.py::TestPeftCustomModel::test_only_params_are_updated[Embedding + transformers Conv1D 1 trainable_tokens-EmbConv1D-TrainableTokensConfig-config_kwargs180] - AssertionError: assert not True`
> `FAILED tests/test_custom_models.py::TestPeftCustomModel::test_disable_adapters_with_merging[Embedding + transformers Conv1D 1 trainable_tokens-EmbConv1D-TrainableTokensConfig-config_kwargs180] - AssertionError: assert not True`

This test fails because sometimes the gradients of the trainable tokens delta is 0 but only when training on CUDA,
CPU is fine.

This is a weird one and I'm not sure if this is a good fix or not. I encountered this error on two machines
(1xL40S and 4xA10G) and I was not able to pinpoint this to something particular in the environment, i.e.
PEFT version (tested v0.17 to main), transformers version (tested 4.5{5,6,7}, 5.0), CUDA version (tested 12.6, 12.8)
or torch version (tested 2.7, 2.8, 2.9, 2.10). I also set `LD_LIBRARY_PATH=` before running pytest to exclude
cuDNN libraries that come preinstalled on the EC2 instance.

Removing the ReLU in `EmbConv1DModel` as well as boosting the Conv1D weights will fix the error. Replacing
the ReLU with `Threshold(0, 0)` has the same behavior. It depends on the seed, i.e. if the initialization of
`Conv1D` is favorable the bug will not trigger.

I tried pinpointing it on `index_copy` but it is not `index_copy` by itself that is the problem. Maybe we will just
have to live with this?

> `FAILED tests/test_common_gpu.py::PeftGPUCommonTests::test_dora_ephemeral_gpu_offload_multigpu - RuntimeError: Expected all tensors to be on the same device, but got mat2 is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_mm)`

This is caused by a bug introduced in huggingface#2960 - `ephemeral_gpu_offload` is not passed to the variant and therefore
never utilized.

> `FAILED tests/test_gpu_examples.py::PeftBnbGPUExampleTests::test_seq2seq_lm_training_single_gpu - AttributeError: 'T5ForConditionalGeneration' object has no attribute 'hf_device_map'`

This is caused by transformers@315dcbe45cee1489a32fc228a80502b0a150936c which disables accelerate hooks if the
device map only contains one device. I confirmed that just specifying one value moves the model to that device even
without accelerate hook invocation. I also tested having two devices (cpu + cuda:0) and in that case a device map is
present. Therefore this only needs an added `hasattr` check to be compatible with transformers v5.
githubnemo added a commit that referenced this pull request Feb 12, 2026
This addresses some of the errors reported by running the tests
on a single GPU machine.

I will list the error messages and a short explanation of the fix.

> `FAILED tests/test_common_gpu.py::PeftGPUCommonTests::test_lora_gptq_quantization_from_pretrained_safetensors - NameError: name 'BACKEND' is not defined`

The test was using GPTQModel without marking the test as requiring it leading to an error. This is fixed
by marking the test with `requires_gptqmodel`.

> `FAILED tests/test_custom_models.py::TestPeftCustomModel::test_only_params_are_updated[Embedding + transformers Conv1D 1 trainable_tokens-EmbConv1D-TrainableTokensConfig-config_kwargs180] - AssertionError: assert not True`
> `FAILED tests/test_custom_models.py::TestPeftCustomModel::test_disable_adapters_with_merging[Embedding + transformers Conv1D 1 trainable_tokens-EmbConv1D-TrainableTokensConfig-config_kwargs180] - AssertionError: assert not True`

This test fails because sometimes the gradients of the trainable tokens delta is 0 but only when training on CUDA,
CPU is fine.

This is a weird one and I'm not sure if this is a good fix or not. I encountered this error on two machines
(1xL40S and 4xA10G) and I was not able to pinpoint this to something particular in the environment, i.e.
PEFT version (tested v0.17 to main), transformers version (tested 4.5{5,6,7}, 5.0), CUDA version (tested 12.6, 12.8)
or torch version (tested 2.7, 2.8, 2.9, 2.10). I also set `LD_LIBRARY_PATH=` before running pytest to exclude
cuDNN libraries that come preinstalled on the EC2 instance.

Removing the ReLU in `EmbConv1DModel` as well as boosting the Conv1D weights will fix the error. Replacing
the ReLU with `Threshold(0, 0)` has the same behavior. It depends on the seed, i.e. if the initialization of
`Conv1D` is favorable the bug will not trigger.

I tried pinpointing it on `index_copy` but it is not `index_copy` by itself that is the problem. Maybe we will just
have to live with this?

> `FAILED tests/test_common_gpu.py::PeftGPUCommonTests::test_dora_ephemeral_gpu_offload_multigpu - RuntimeError: Expected all tensors to be on the same device, but got mat2 is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_mm)`

This is caused by a bug introduced in #2960 - `ephemeral_gpu_offload` is not passed to the variant and therefore
never utilized.

> `FAILED tests/test_gpu_examples.py::PeftBnbGPUExampleTests::test_seq2seq_lm_training_single_gpu - AttributeError: 'T5ForConditionalGeneration' object has no attribute 'hf_device_map'`

This is caused by transformers@315dcbe45cee1489a32fc228a80502b0a150936c which disables accelerate hooks if the
device map only contains one device. I confirmed that just specifying one value moves the model to that device even
without accelerate hook invocation. I also tested having two devices (cpu + cuda:0) and in that case a device map is
present. Therefore this only needs an added `hasattr` check to be compatible with transformers v5.

Co-authored-by: nemo <git@ningu.net>
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Feb 26, 2026
The layer initialization was refactored in huggingface#2960. This introduced a bug
when initializing AdaLoRA with GPTQ layers because some parameters were
missing. This bug is now fixed.

The same bugfix is contained in huggingface#3047 but is provided here separately to
allow merging it more easily.
BenjaminBossan added a commit that referenced this pull request Mar 3, 2026
The layer initialization was refactored in #2960. This introduced a bug
when initializing AdaLoRA with GPTQ layers because some parameters were
missing. This bug is now fixed.

The same bugfix is contained in #3047 but is provided here separately to
allow merging it more easily.
BenjaminBossan added a commit that referenced this pull request Apr 17, 2026
Continuation of #2960.

This PR targets the __init__ and update_layer methods of the other relevant
PEFT methods (i.e. everything except for prompt learning). The goal is to pass
the corresponding PEFT config directly and let the layers pick out relevant
arguments, instead of having the model classes pick out the arguments and pass
them to the layers. The advantage is that this reduces code duplication.
Moreover, if, say, a new init argument is added, there is no longer the need to
update the code in multiple places just to ensure that the argument is passed
to the layers correctly.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants