Skip to content
Merged
259 changes: 259 additions & 0 deletions tests/test_peft_weight_converter_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
import importlib.util
import inspect
import sys
import threading
import types
from pathlib import Path

import pytest


REPO_ROOT = Path(__file__).resolve().parents[1]
IMPORT_FIXES = REPO_ROOT / "unsloth" / "import_fixes.py"


def _load_patch_function():
spec = importlib.util.spec_from_file_location(
"_unsloth_import_fixes_under_test", IMPORT_FIXES
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module.patch_peft_weight_converter_compatibility


def _install_fake_peft(twc_namespace):
peft_pkg = types.ModuleType("peft")
peft_pkg.__path__ = []
peft_utils = types.ModuleType("peft.utils")
peft_utils.__path__ = []
twc = types.ModuleType("peft.utils.transformers_weight_conversion")
for k, v in twc_namespace.items():
setattr(twc, k, v)
peft_utils.transformers_weight_conversion = twc
sys.modules["peft"] = peft_pkg
sys.modules["peft.utils"] = peft_utils
sys.modules["peft.utils.transformers_weight_conversion"] = twc
return twc


@pytest.fixture(autouse = True)
def _restore_peft_modules():
saved = {
k: sys.modules.get(k)
for k in (
"peft",
"peft.utils",
"peft.utils.transformers_weight_conversion",
)
}
yield
for k, v in saved.items():
if v is None:
sys.modules.pop(k, None)
else:
sys.modules[k] = v


class _LegacyConverter:
def __init__(self, source_patterns, target_patterns, operations):
self.source_patterns = source_patterns
self.target_patterns = target_patterns
self.operations = operations
self.distributed_operation = None
self.quantization_operation = None


class _ModernConverter:
def __init__(
self,
source_patterns,
target_patterns,
operations,
distributed_operation = None,
quantization_operation = None,
):
self.source_patterns = source_patterns
self.target_patterns = target_patterns
self.operations = operations
self.distributed_operation = distributed_operation
self.quantization_operation = quantization_operation


def _make_legacy_converter():
return _LegacyConverter(["src.*"], ["tgt.*"], [])


def _make_modern_converter():
return _ModernConverter(["src.*"], ["tgt.*"], [])


def _build_that_calls_init(weight_conversions, adapter_name, peft_config = None):
out = []
for c in weight_conversions or []:
out.append(
c.__class__(
source_patterns = c.source_patterns,
target_patterns = c.target_patterns,
operations = c.operations,
distributed_operation = "dist-x",
quantization_operation = "quant-y",
)
)
return out


def test_two_arg_call_preserves_upstream_signature():
twc = _install_fake_peft({"build_peft_weight_mapping": _build_that_calls_init})
patch = _load_patch_function()
patch()

sig = inspect.signature(twc.build_peft_weight_mapping)
assert "peft_config" in sig.parameters
assert sig.parameters["peft_config"].default is None

out = twc.build_peft_weight_mapping([_make_legacy_converter()], "default")
assert len(out) == 1
assert out[0].distributed_operation == "dist-x"
assert out[0].quantization_operation == "quant-y"


def test_legacy_init_succeeds_after_patch():
twc = _install_fake_peft({"build_peft_weight_mapping": _build_that_calls_init})
patch = _load_patch_function()
patch()

out = twc.build_peft_weight_mapping([_make_legacy_converter()], "default", None)
assert len(out) == 1
assert out[0].distributed_operation == "dist-x"
assert out[0].quantization_operation == "quant-y"


def test_modern_init_not_patched():
twc = _install_fake_peft({"build_peft_weight_mapping": _build_that_calls_init})
pre_init = _ModernConverter.__init__
patch = _load_patch_function()
patch()

twc.build_peft_weight_mapping([_make_modern_converter()], "default", None)
assert _ModernConverter.__init__ is pre_init


def test_class_init_restored_after_call():
twc = _install_fake_peft({"build_peft_weight_mapping": _build_that_calls_init})
pre_init = _LegacyConverter.__init__
patch = _load_patch_function()
patch()

twc.build_peft_weight_mapping([_make_legacy_converter()], "default", None)
assert _LegacyConverter.__init__ is pre_init


def test_class_init_restored_after_original_build_raises():
def _raise(weight_conversions, adapter_name, peft_config = None):
raise RuntimeError("simulated PEFT failure")

twc = _install_fake_peft({"build_peft_weight_mapping": _raise})
pre_init = _LegacyConverter.__init__
patch = _load_patch_function()
patch()

with pytest.raises(RuntimeError):
twc.build_peft_weight_mapping([_make_legacy_converter()], "default", None)
assert _LegacyConverter.__init__ is pre_init


def test_partial_patch_restored_when_inspect_signature_raises_mid_loop():
twc = _install_fake_peft({"build_peft_weight_mapping": _build_that_calls_init})
pre_legacy = _LegacyConverter.__init__

class _BadInitConverter:
def __init__(self, source_patterns, target_patterns, operations):
self.source_patterns = source_patterns
self.target_patterns = target_patterns
self.operations = operations

pre_bad = _BadInitConverter.__init__
patch = _load_patch_function()
patch()

real_signature = inspect.signature

def _fake_signature(callable_):
if callable_ is _BadInitConverter.__init__:
raise ValueError("inspect.signature failed mid-loop")
return real_signature(callable_)

inspect.signature = _fake_signature
try:
legacy = _LegacyConverter(["src.*"], ["tgt.*"], [])
bad = _BadInitConverter.__new__(_BadInitConverter)
bad.source_patterns = ["src.*"]
bad.target_patterns = ["tgt.*"]
bad.operations = []
with pytest.raises(ValueError):
twc.build_peft_weight_mapping([legacy, bad], "default", None)
finally:
inspect.signature = real_signature

assert _LegacyConverter.__init__ is pre_legacy
assert _BadInitConverter.__init__ is pre_bad


def test_idempotent_install_does_not_double_wrap():
twc = _install_fake_peft({"build_peft_weight_mapping": _build_that_calls_init})
patch = _load_patch_function()
patch()
first_wrapped = twc.build_peft_weight_mapping
patch()
assert twc.build_peft_weight_mapping is first_wrapped


def test_concurrent_legacy_calls_no_typeerror():
import time

def _slow_build(weight_conversions, adapter_name, peft_config = None):
time.sleep(0.05)
return _build_that_calls_init(weight_conversions, adapter_name, peft_config)

twc = _install_fake_peft({"build_peft_weight_mapping": _slow_build})
patch = _load_patch_function()
patch()

errors = []
results = []
start = threading.Event()

def _worker():
start.wait(timeout = 10)
try:
out = twc.build_peft_weight_mapping(
[_make_legacy_converter()], "default", None
)
results.append(out)
except Exception as e:
errors.append(e)

threads = [threading.Thread(target = _worker) for _ in range(8)]
for t in threads:
t.start()
start.set()
for t in threads:
t.join(timeout = 15)

assert errors == []
assert len(results) == 8
for out in results:
assert out[0].distributed_operation == "dist-x"
assert out[0].quantization_operation == "quant-y"
assert _LegacyConverter.__init__.__qualname__.startswith("_LegacyConverter")


def test_empty_conversions_short_circuits_without_patching():
twc = _install_fake_peft({"build_peft_weight_mapping": _build_that_calls_init})
pre_init = _LegacyConverter.__init__
patch = _load_patch_function()
patch()

out = twc.build_peft_weight_mapping([], "default", None)
assert out == []
assert _LegacyConverter.__init__ is pre_init
3 changes: 3 additions & 0 deletions unsloth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@
patch_torchcodec_audio_decoder,
disable_torchcodec_if_broken,
disable_broken_wandb,
patch_peft_weight_converter_compatibility,
)

fix_xformers_performance_issue()
Expand All @@ -176,6 +177,7 @@
patch_torchcodec_audio_decoder()
disable_torchcodec_if_broken()
disable_broken_wandb()
patch_peft_weight_converter_compatibility()

del fix_xformers_performance_issue
del fix_vllm_aimv2_issue
Expand All @@ -197,6 +199,7 @@
del patch_torchcodec_audio_decoder
del disable_torchcodec_if_broken
del disable_broken_wandb
del patch_peft_weight_converter_compatibility

# Torch 2.4 has including_emulation
if DEVICE_TYPE == "cuda":
Expand Down
83 changes: 83 additions & 0 deletions unsloth/import_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import warnings
import sys
import functools
import inspect

# We cannot do from unsloth_zoo.log import logger since FBGEMM might cause seg faults.
UNSLOTH_ENABLE_LOGGING = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") in (
Expand Down Expand Up @@ -1371,6 +1372,88 @@ def disable_broken_wandb():
os.environ["WANDB_DISABLED"] = "true"


def patch_peft_weight_converter_compatibility():
"""Allow PEFT converter rebuilds on legacy converter constructors."""
try:
from peft.utils import transformers_weight_conversion as twc
except (ImportError, AttributeError):
return

if getattr(twc, "_unsloth_weight_converter_compat_patch", False):
return

import threading

original_build = twc.build_peft_weight_mapping
patch_lock = threading.RLock()

def _patch_weight_converter_ctors(weight_conversions, patched):
seen_classes = set()

for conversion in weight_conversions:
conversion_cls = conversion.__class__
if conversion_cls in seen_classes:
continue
seen_classes.add(conversion_cls)

original_init = conversion_cls.__init__
params = inspect.signature(original_init).parameters
supports_kwargs = any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
)
supports_distributed = "distributed_operation" in params
supports_quantization = "quantization_operation" in params
if supports_kwargs or (supports_distributed and supports_quantization):
continue

def _compat_init(
self,
*args,
__original_init = original_init,
__supports_distributed = supports_distributed,
__supports_quantization = supports_quantization,
**kwargs,
):
unsupported = {}
if not __supports_distributed and "distributed_operation" in kwargs:
unsupported["distributed_operation"] = kwargs.pop(
"distributed_operation"
)
if not __supports_quantization and "quantization_operation" in kwargs:
unsupported["quantization_operation"] = kwargs.pop(
"quantization_operation"
)
result = __original_init(self, *args, **kwargs)
for name, value in unsupported.items():
if hasattr(self, name):
setattr(self, name, value)
return result

conversion_cls.__init__ = _compat_init
patched.append((conversion_cls, original_init))

@functools.wraps(original_build)
def _build_peft_weight_mapping_compat(
weight_conversions,
adapter_name,
peft_config = None,
):
Comment on lines +1436 to +1440

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve optional peft_config in patched wrapper

build_peft_weight_mapping in PEFT accepts peft_config=None, but the monkeypatched replacement makes peft_config mandatory. After this patch runs, any existing call path that relies on the original default (e.g. calling with only weight_conversions and adapter_name) will now raise TypeError, which is a runtime API regression introduced by this commit. Keep the default (and ideally forward *args/**kwargs) so the wrapper remains signature-compatible with the original function.

Useful? React with 👍 / 👎.

Comment on lines +1436 to +1440

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve original weight-mapping wrapper signature

The monkeypatched _build_peft_weight_mapping_compat no longer matches upstream build_peft_weight_mapping call semantics: it requires peft_config instead of keeping the original default (None). After this patch is applied, any call site that previously invoked the function with only weight_conversions and adapter_name will raise TypeError at runtime, which is an API regression in checkpoint conversion paths.

Useful? React with 👍 / 👎.

if not weight_conversions:
return original_build(weight_conversions, adapter_name, peft_config)

patched_classes = []
with patch_lock:
try:
_patch_weight_converter_ctors(weight_conversions, patched_classes)
return original_build(weight_conversions, adapter_name, peft_config)
finally:
for conversion_cls, original_init in patched_classes:
conversion_cls.__init__ = original_init

twc.build_peft_weight_mapping = _build_peft_weight_mapping_compat
twc._unsloth_weight_converter_compat_patch = True


CAUSAL_CONV1D_BROKEN = False
_CAUSAL_CONV1D_PREFIX = "causal_conv1d"
_CAUSAL_CONV1D_BLOCKER_SENTINEL = "_unsloth_causal_conv1d_blocker"
Expand Down