"""Tests for ``patch_peft_weight_converter_compatibility`` in ``unsloth/import_fixes.py``.

The patch function is loaded straight from the source file (without importing
the ``unsloth`` package) and exercised against a fake
``peft.utils.transformers_weight_conversion`` module installed in
``sys.modules``.
"""

import importlib.util
import inspect
import sys
import threading
import time
import types
from pathlib import Path

import pytest


REPO_ROOT = Path(__file__).resolve().parents[1]
IMPORT_FIXES = REPO_ROOT / "unsloth" / "import_fixes.py"

# Module names that `_install_fake_peft` overwrites in `sys.modules`.
_FAKE_PEFT_MODULES = (
    "peft",
    "peft.utils",
    "peft.utils.transformers_weight_conversion",
)

# Distinguishes "no entry" from an entry whose value is None: a None value in
# `sys.modules` is itself meaningful (it blocks the import) and must be restored.
_MISSING = object()


def _load_patch_function():
    """Execute ``import_fixes.py`` as a standalone module and return the patch function."""
    spec = importlib.util.spec_from_file_location(
        "_unsloth_import_fixes_under_test", IMPORT_FIXES
    )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.patch_peft_weight_converter_compatibility


def _install_fake_peft(twc_namespace):
    """Install a fake ``peft`` package hierarchy and return the fake conversion module.

    Every key/value of ``twc_namespace`` becomes an attribute of the fake
    ``peft.utils.transformers_weight_conversion`` module.
    """
    peft_pkg = types.ModuleType("peft")
    peft_pkg.__path__ = []
    peft_utils = types.ModuleType("peft.utils")
    peft_utils.__path__ = []
    twc = types.ModuleType("peft.utils.transformers_weight_conversion")
    for k, v in twc_namespace.items():
        setattr(twc, k, v)
    peft_utils.transformers_weight_conversion = twc
    sys.modules["peft"] = peft_pkg
    sys.modules["peft.utils"] = peft_utils
    sys.modules["peft.utils.transformers_weight_conversion"] = twc
    return twc


@pytest.fixture(autouse = True)
def _restore_peft_modules():
    """Put the real ``peft`` entries of ``sys.modules`` back after every test."""
    saved = {name: sys.modules.get(name, _MISSING) for name in _FAKE_PEFT_MODULES}
    yield
    for name, module in saved.items():
        if module is _MISSING:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = module


class _LegacyConverter:
    """Converter whose constructor does not accept the two newer keywords."""

    def __init__(self, source_patterns, target_patterns, operations):
        self.source_patterns = source_patterns
        self.target_patterns = target_patterns
        self.operations = operations
        self.distributed_operation = None
        self.quantization_operation = None


class _ModernConverter:
    """Converter whose constructor already accepts both newer keywords."""

    def __init__(
        self,
        source_patterns,
        target_patterns,
        operations,
        distributed_operation = None,
        quantization_operation = None,
    ):
        self.source_patterns = source_patterns
        self.target_patterns = target_patterns
        self.operations = operations
        self.distributed_operation = distributed_operation
        self.quantization_operation = quantization_operation


def _make_legacy_converter():
    return _LegacyConverter(["src.*"], ["tgt.*"], [])


def _make_modern_converter():
    return _ModernConverter(["src.*"], ["tgt.*"], [])


def _build_that_calls_init(weight_conversions, adapter_name, peft_config = None):
    """Stand-in for PEFT's builder: rebuild each converter with both newer keywords."""
    out = []
    for c in weight_conversions or []:
        out.append(
            c.__class__(
                source_patterns = c.source_patterns,
                target_patterns = c.target_patterns,
                operations = c.operations,
                distributed_operation = "dist-x",
                quantization_operation = "quant-y",
            )
        )
    return out


def test_two_arg_call_preserves_upstream_signature():
    twc = _install_fake_peft({"build_peft_weight_mapping": _build_that_calls_init})
    patch = _load_patch_function()
    patch()

    sig = inspect.signature(twc.build_peft_weight_mapping)
    assert "peft_config" in sig.parameters
    assert sig.parameters["peft_config"].default is None

    out = twc.build_peft_weight_mapping([_make_legacy_converter()], "default")
    assert len(out) == 1
    assert out[0].distributed_operation == "dist-x"
    assert out[0].quantization_operation == "quant-y"


def test_legacy_init_succeeds_after_patch():
    twc = _install_fake_peft({"build_peft_weight_mapping": _build_that_calls_init})
    patch = _load_patch_function()
    patch()

    out = twc.build_peft_weight_mapping([_make_legacy_converter()], "default", None)
    assert len(out) == 1
    assert out[0].distributed_operation == "dist-x"
    assert out[0].quantization_operation == "quant-y"


def test_modern_init_not_patched():
    # Record the constructor seen *during* the build: checking only after the
    # call cannot tell "never patched" apart from "patched, then restored".
    seen_inits = []

    def _recording_build(weight_conversions, adapter_name, peft_config = None):
        seen_inits.extend(type(c).__init__ for c in weight_conversions)
        return _build_that_calls_init(weight_conversions, adapter_name, peft_config)

    twc = _install_fake_peft({"build_peft_weight_mapping": _recording_build})
    pre_init = _ModernConverter.__init__
    patch = _load_patch_function()
    patch()

    twc.build_peft_weight_mapping([_make_modern_converter()], "default", None)
    assert seen_inits == [pre_init]
    assert _ModernConverter.__init__ is pre_init


def test_class_init_restored_after_call():
    twc = _install_fake_peft({"build_peft_weight_mapping": _build_that_calls_init})
    pre_init = _LegacyConverter.__init__
    patch = _load_patch_function()
    patch()

    twc.build_peft_weight_mapping([_make_legacy_converter()], "default", None)
    assert _LegacyConverter.__init__ is pre_init


def test_class_init_restored_after_original_build_raises():
    def _raise(weight_conversions, adapter_name, peft_config = None):
        raise RuntimeError("simulated PEFT failure")

    twc = _install_fake_peft({"build_peft_weight_mapping": _raise})
    pre_init = _LegacyConverter.__init__
    patch = _load_patch_function()
    patch()

    with pytest.raises(RuntimeError):
        twc.build_peft_weight_mapping([_make_legacy_converter()], "default", None)
    assert _LegacyConverter.__init__ is pre_init


def test_partial_patch_restored_when_inspect_signature_raises_mid_loop(monkeypatch):
    twc = _install_fake_peft({"build_peft_weight_mapping": _build_that_calls_init})
    pre_legacy = _LegacyConverter.__init__

    class _BadInitConverter:
        def __init__(self, source_patterns, target_patterns, operations):
            self.source_patterns = source_patterns
            self.target_patterns = target_patterns
            self.operations = operations

    pre_bad = _BadInitConverter.__init__
    patch = _load_patch_function()
    patch()

    real_signature = inspect.signature

    def _fake_signature(callable_, *args, **kwargs):
        # Fail only for the second class so the first one is already patched
        # when the error surfaces; forward everything else untouched.
        if callable_ is _BadInitConverter.__init__:
            raise ValueError("inspect.signature failed mid-loop")
        return real_signature(callable_, *args, **kwargs)

    legacy = _LegacyConverter(["src.*"], ["tgt.*"], [])
    bad = _BadInitConverter.__new__(_BadInitConverter)
    bad.source_patterns = ["src.*"]
    bad.target_patterns = ["tgt.*"]
    bad.operations = []

    # Scoped patch: `inspect.signature` is restored as soon as the block exits,
    # even if an assertion inside it fails.
    with monkeypatch.context() as scoped:
        scoped.setattr(inspect, "signature", _fake_signature)
        with pytest.raises(ValueError):
            twc.build_peft_weight_mapping([legacy, bad], "default", None)

    assert _LegacyConverter.__init__ is pre_legacy
    assert _BadInitConverter.__init__ is pre_bad


def test_idempotent_install_does_not_double_wrap():
    twc = _install_fake_peft({"build_peft_weight_mapping": _build_that_calls_init})
    patch = _load_patch_function()
    patch()
    first_wrapped = twc.build_peft_weight_mapping
    patch()
    assert twc.build_peft_weight_mapping is first_wrapped


def test_concurrent_legacy_calls_no_typeerror():
    def _slow_build(weight_conversions, adapter_name, peft_config = None):
        # Widen the window in which another thread could observe a
        # half-patched / half-restored constructor.
        time.sleep(0.05)
        return _build_that_calls_init(weight_conversions, adapter_name, peft_config)

    twc = _install_fake_peft({"build_peft_weight_mapping": _slow_build})
    pre_init = _LegacyConverter.__init__
    patch = _load_patch_function()
    patch()

    errors = []
    results = []
    start = threading.Event()

    def _worker():
        start.wait(timeout = 10)
        try:
            out = twc.build_peft_weight_mapping(
                [_make_legacy_converter()], "default", None
            )
            results.append(out)
        except Exception as e:
            errors.append(e)

    # Daemon threads: a deadlocked worker must fail the test, not hang the run.
    threads = [threading.Thread(target = _worker, daemon = True) for _ in range(8)]
    for t in threads:
        t.start()
    start.set()
    for t in threads:
        t.join(timeout = 15)

    assert not any(t.is_alive() for t in threads)
    assert errors == []
    assert len(results) == 8
    for out in results:
        assert out[0].distributed_operation == "dist-x"
        assert out[0].quantization_operation == "quant-y"
    assert _LegacyConverter.__init__ is pre_init


def test_empty_conversions_short_circuits_without_patching():
    twc = _install_fake_peft({"build_peft_weight_mapping": _build_that_calls_init})
    pre_init = _LegacyConverter.__init__
    patch = _load_patch_function()
    patch()

    out = twc.build_peft_weight_mapping([], "default", None)
    assert out == []
    assert _LegacyConverter.__init__ is pre_init
def patch_peft_weight_converter_compatibility():
    """Allow PEFT converter rebuilds on legacy converter constructors.

    Replaces ``peft.utils.transformers_weight_conversion.build_peft_weight_mapping``
    with a wrapper. For the duration of each call, the wrapper gives every
    converter class found in ``weight_conversions`` a temporary ``__init__``
    when that class's constructor accepts neither ``**kwargs`` nor both of
    ``distributed_operation`` / ``quantization_operation``. The temporary
    constructor strips the keywords the real one does not declare, calls the
    real one, and then stores the stripped values on the instance if the
    instance already has attributes of those names.

    The original constructors are always restored before the wrapper returns
    or raises. Patching, the wrapped call and restoration all run under one
    re-entrant lock, so concurrent callers are serialized.

    This function does nothing when the PEFT module cannot be imported, when
    it exposes no callable ``build_peft_weight_mapping``, or when it has
    already been patched. It is invoked once from ``unsloth/__init__.py``.
    """
    try:
        from peft.utils import transformers_weight_conversion as twc
    except (ImportError, AttributeError):
        return

    if getattr(twc, "_unsloth_weight_converter_compat_patch", False):
        return

    # A PEFT build that ships the module without this function has nothing to
    # wrap; bail out instead of raising AttributeError while unsloth imports.
    original_build = getattr(twc, "build_peft_weight_mapping", None)
    if not callable(original_build):
        return

    import threading

    patch_lock = threading.RLock()

    def _make_compat_init(original_init, supports_distributed, supports_quantization):
        # Build the temporary constructor for one class. The values are bound
        # through this factory's scope so every class gets its own copy (no
        # late-binding surprises) and nothing extra leaks into the signature.
        def _compat_init(self, *args, **kwargs):
            unsupported = {}
            if not supports_distributed and "distributed_operation" in kwargs:
                unsupported["distributed_operation"] = kwargs.pop(
                    "distributed_operation"
                )
            if not supports_quantization and "quantization_operation" in kwargs:
                unsupported["quantization_operation"] = kwargs.pop(
                    "quantization_operation"
                )
            result = original_init(self, *args, **kwargs)
            # Only fill in attributes the real constructor created itself.
            for name, value in unsupported.items():
                if hasattr(self, name):
                    setattr(self, name, value)
            return result

        return _compat_init

    def _patch_weight_converter_ctors(weight_conversions, patched):
        # Appends one (class, original __init__, class owned __init__) entry to
        # `patched` per class it modifies, so the caller can undo the work even
        # if this function raises part-way through the loop.
        seen_classes = set()

        for conversion in weight_conversions:
            conversion_cls = conversion.__class__
            if conversion_cls in seen_classes:
                continue
            seen_classes.add(conversion_cls)

            original_init = conversion_cls.__init__
            params = inspect.signature(original_init).parameters
            supports_kwargs = any(
                p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
            )
            supports_distributed = "distributed_operation" in params
            supports_quantization = "quantization_operation" in params
            if supports_kwargs or (supports_distributed and supports_quantization):
                continue

            # Remember whether the class defines __init__ itself or inherits
            # it: an inherited one must be restored by deleting our attribute,
            # not by pinning a copy of the parent's constructor on the subclass.
            owns_init = "__init__" in conversion_cls.__dict__
            conversion_cls.__init__ = _make_compat_init(
                original_init, supports_distributed, supports_quantization
            )
            patched.append((conversion_cls, original_init, owns_init))

    @functools.wraps(original_build)
    def _build_peft_weight_mapping_compat(
        weight_conversions,
        adapter_name,
        peft_config = None,
    ):
        if not weight_conversions:
            return original_build(weight_conversions, adapter_name, peft_config)

        patched_classes = []
        with patch_lock:
            try:
                _patch_weight_converter_ctors(weight_conversions, patched_classes)
                return original_build(weight_conversions, adapter_name, peft_config)
            finally:
                # Undo in reverse order of application.
                for conversion_cls, original_init, owns_init in reversed(
                    patched_classes
                ):
                    if owns_init:
                        conversion_cls.__init__ = original_init
                    else:
                        del conversion_cls.__init__

    twc.build_peft_weight_mapping = _build_peft_weight_mapping_compat
    twc._unsloth_weight_converter_compat_patch = True


CAUSAL_CONV1D_BROKEN = False
_CAUSAL_CONV1D_PREFIX = "causal_conv1d"
_CAUSAL_CONV1D_BLOCKER_SENTINEL = "_unsloth_causal_conv1d_blocker"