diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 01d9997d7d..0de186ae18 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -313,6 +313,13 @@ def _apply_adapter_patches(self):
 
             patch_peft_prep_code()
 
+        if self.cfg.adapter:
+            from axolotl.monkeypatch.peft_modules_to_save import (
+                patch_peft_modules_to_save_kwargs,
+            )
+
+            patch_peft_modules_to_save_kwargs()
+
     def _apply_flex_attention_patches(self):
         """Apply patches for flexible attention."""
         if self.cfg.flex_attention:
diff --git a/src/axolotl/monkeypatch/peft_modules_to_save.py b/src/axolotl/monkeypatch/peft_modules_to_save.py
new file mode 100644
index 0000000000..e0727c524c
--- /dev/null
+++ b/src/axolotl/monkeypatch/peft_modules_to_save.py
@@ -0,0 +1,88 @@
+"""Patch PEFT's AuxiliaryTrainingWrapper / ModulesToSaveWrapper so kwargs-only
+forward calls work (e.g. Gemma 4 vision_tower / embed_vision in lora_modules_to_save)."""
+
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+# Class attribute set on AuxiliaryTrainingWrapper once the patch is applied.
+_PATCHED_ATTR = "_axolotl_modules_to_save_kwargs_patched"
+
+
+def _patched_forward(self, *args, **kwargs):
+    """Replacement for ``AuxiliaryTrainingWrapper.forward`` without a required ``x``.
+
+    Routes to ``_forward_wrapped_passthrough`` when adapters are disabled or an
+    active adapter is missing from ``self._adapters``, to ``_forward_wrapped`` when
+    no ``adapter_names`` kwarg is given, and to ``_mixed_batch_forward`` otherwise.
+
+    Raises:
+        TypeError: if ``adapter_names`` is given without any positional input.
+    """
+    # _check_forward_args only validates len(x) vs len(adapter_names); skip when no positional x.
+    if args:
+        self._check_forward_args(*args, **kwargs)
+
+    adapter_names = kwargs.pop("adapter_names", None)
+
+    if self.disable_adapters or any(
+        adapter not in self._adapters for adapter in self.active_adapters
+    ):
+        return self._forward_wrapped_passthrough(*args, **kwargs)
+
+    if adapter_names is None:
+        return self._forward_wrapped(*args, **kwargs)
+
+    # Mixed-batch path needs positional input for sub-batch indexing; leave unchanged.
+    if not args:
+        # Fail with an actionable message; callers still get a TypeError.
+        raise TypeError(
+            "adapter_names (mixed-batch forward) requires a positional input; "
+            "kwargs-only calls are not supported on this path"
+        )
+    return self._mixed_batch_forward(*args, adapter_names=adapter_names, **kwargs)
+
+
+def _patched_forward_wrapped(self, *args, **kwargs):
+    """Replacement for ``ModulesToSaveWrapper._forward_wrapped``.
+
+    Calls the ``modules_to_save`` entry of the first active adapter, or the
+    passthrough when there is no active adapter.
+    """
+    if not self.active_adapters:
+        return self._forward_wrapped_passthrough(*args, **kwargs)
+    return self.modules_to_save[self.active_adapters[0]](*args, **kwargs)
+
+
+def _patched_forward_wrapped_passthrough(self, *args, **kwargs):
+    """Replacement for ``ModulesToSaveWrapper._forward_wrapped_passthrough``.
+
+    Forwards all arguments unchanged to ``self.original_module``.
+    """
+    return self.original_module(*args, **kwargs)
+
+
+def patch_peft_modules_to_save_kwargs() -> None:
+    """Apply the kwargs-compatible forward patch to PEFT. Idempotent.
+
+    Replaces ``AuxiliaryTrainingWrapper.forward`` plus ``_forward_wrapped`` and
+    ``_forward_wrapped_passthrough`` on ``ModulesToSaveWrapper``, then flags
+    ``AuxiliaryTrainingWrapper`` so later calls return without re-patching.
+    """
+    from peft.utils.other import AuxiliaryTrainingWrapper, ModulesToSaveWrapper
+
+    if getattr(AuxiliaryTrainingWrapper, _PATCHED_ATTR, False):
+        return
+
+    AuxiliaryTrainingWrapper.forward = _patched_forward
+    ModulesToSaveWrapper._forward_wrapped = _patched_forward_wrapped
+    ModulesToSaveWrapper._forward_wrapped_passthrough = (
+        _patched_forward_wrapped_passthrough
+    )
+
+    setattr(AuxiliaryTrainingWrapper, _PATCHED_ATTR, True)
+    LOG.debug(
+        "Patched peft.AuxiliaryTrainingWrapper / ModulesToSaveWrapper to accept "
+        "kwargs-only forward calls (enables full-FT of modules called with "
+        "keyword args, e.g. Gemma 4 vision_tower/embed_vision)"
+    )
diff --git a/tests/monkeypatch/test_peft_modules_to_save.py b/tests/monkeypatch/test_peft_modules_to_save.py
new file mode 100644
index 0000000000..d4f38b7569
--- /dev/null
+++ b/tests/monkeypatch/test_peft_modules_to_save.py
@@ -0,0 +1,149 @@
+"""Tests for axolotl.monkeypatch.peft_modules_to_save."""
+
+import pytest
+import torch
+import torch.nn as nn
+
+pytest.importorskip("peft")
+
+
+@pytest.fixture(scope="module")
+def patched():
+    """Apply the PEFT kwargs patch once for the tests in this module."""
+    from axolotl.monkeypatch.peft_modules_to_save import (
+        patch_peft_modules_to_save_kwargs,
+    )
+
+    patch_peft_modules_to_save_kwargs()
+    yield
+
+
+@pytest.fixture
+def wrap_module(patched):
+    """Wrap ``module`` like PEFT wraps a ``modules_to_save`` entry."""
+    from peft.utils.other import ModulesToSaveWrapper
+
+    def _wrap(module: nn.Module) -> ModulesToSaveWrapper:
+        w = ModulesToSaveWrapper(module, "default")
+        w.set_adapter("default")
+        return w
+
+    return _wrap
+
+
+def test_kwargs_only_forward(wrap_module):
+    """Gemma 4 vision_tower shape: self.vision_tower(pixel_values=...)."""
+
+    class KwargsOnly(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.lin = nn.Linear(8, 8)
+
+        def forward(self, *, pixel_values):
+            return self.lin(pixel_values)
+
+    w = wrap_module(KwargsOnly())
+    out = w(pixel_values=torch.randn(2, 8))
+    assert out.shape == (2, 8)
+
+
+def test_positional_only_forward(wrap_module):
+    """embed_tokens shape: self.embed_tokens(input_ids)."""
+
+    class PosOnly(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.emb = nn.Embedding(100, 8)
+
+        def forward(self, x):
+            return self.emb(x)
+
+    w = wrap_module(PosOnly())
+    out = w(torch.tensor([1, 2, 3]))
+    assert out.shape == (3, 8)
+
+
+def test_mixed_args_kwargs_forward(wrap_module):
+    """Module taking both positional and keyword args."""
+
+    class Mixed(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.lin = nn.Linear(8, 8)
+
+        def forward(self, x, scale=1.0):
+            return self.lin(x) * scale
+
+    w = wrap_module(Mixed())
+    out = w(torch.randn(2, 8), scale=2.0)
+    assert out.shape == (2, 8)
+
+
+def test_kwargs_only_disabled_adapter(wrap_module):
+    """Passthrough branch (enable_adapters(False)) must also accept kwargs-only."""
+
+    class KwargsOnly(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.lin = nn.Linear(8, 8)
+
+        def forward(self, *, pixel_values):
+            return self.lin(pixel_values)
+
+    w = wrap_module(KwargsOnly())
+    w.enable_adapters(False)
+
+    out = w(pixel_values=torch.randn(2, 8))
+    assert out.shape == (2, 8)
+
+
+def test_trainable_tokens_wrapper_not_broken(patched):
+    """Base-class forward replacement must not break TrainableTokensWrapper
+    (positional embed_tokens path must still work)."""
+    try:
+        from peft.utils.other import TrainableTokensWrapper
+    except ImportError:
+        pytest.skip("TrainableTokensWrapper not available in installed PEFT")
+
+    try:
+        w = TrainableTokensWrapper(
+            nn.Embedding(100, 8), "default", token_indices=[0, 1, 2]
+        )
+    except TypeError as exc:
+        pytest.skip(f"TrainableTokensWrapper init signature changed: {exc}")
+    w.set_adapter("default")
+
+    out = w(torch.tensor([0, 1, 2, 5]))
+    assert out.shape == (4, 8)
+
+
+def test_mixed_batch_kwargs_only_raises(wrap_module):
+    """Document current behavior: adapter_names without positional input is
+    unsupported (mixed-batch path is deliberately not kwargs-patched)."""
+
+    class KwargsOnly(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.lin = nn.Linear(8, 8)
+
+        def forward(self, *, pixel_values):
+            return self.lin(pixel_values)
+
+    w = wrap_module(KwargsOnly())
+    with pytest.raises(TypeError, match="adapter_names"):
+        w(pixel_values=torch.randn(2, 8), adapter_names=["default", "default"])
+
+
+def test_patch_is_idempotent():
+    """Applying the patch twice is a no-op."""
+    from peft.utils.other import AuxiliaryTrainingWrapper
+
+    from axolotl.monkeypatch.peft_modules_to_save import (
+        patch_peft_modules_to_save_kwargs,
+    )
+
+    patch_peft_modules_to_save_kwargs()
+    first = AuxiliaryTrainingWrapper.forward
+    patch_peft_modules_to_save_kwargs()
+    second = AuxiliaryTrainingWrapper.forward
+    assert first is second, "patch must be idempotent"