diff --git a/trl/_compat.py b/trl/_compat.py index fecc9a556ba..04923e34b49 100644 --- a/trl/_compat.py +++ b/trl/_compat.py @@ -150,14 +150,15 @@ def __reduce__(self): def _patch_transformers_hybrid_cache() -> None: """ - Fix HybridCache import compatibility for liger_kernel<=0.6.4. + Fix HybridCache import for transformers v5 compatibility. - - Issue: liger_kernel imports HybridCache from transformers.cache_utils + - Issue: liger_kernel and peft import HybridCache from transformers.cache_utils - HybridCache removed in https://github.com/huggingface/transformers/pull/43168 (transformers>=5.0.0.dev0) - - Fixed in https://github.com/linkedin/Liger-Kernel/pull/1002 (will be released in liger_kernel>=0.6.5) - - This patch can be removed when TRL requires liger_kernel>=0.6.5 + - Fixed in liger_kernel: https://github.com/linkedin/Liger-Kernel/pull/1002 (released in v0.6.5) + - Fixed in peft: https://github.com/huggingface/peft/pull/2735 (released in v0.18.0) + - This can be removed when TRL requires liger_kernel>=0.6.5 and peft>=0.18.0 """ - if _is_package_version_below("liger_kernel", "0.6.5"): + if _is_package_version_below("liger_kernel", "0.6.5") or _is_package_version_below("peft", "0.18.0"): try: import transformers @@ -167,7 +168,7 @@ def _patch_transformers_hybrid_cache() -> None: cache_utils.HybridCache = cache_utils.Cache except Exception as e: - warnings.warn(f"Failed to patch liger_kernel HybridCache compatibility: {e}", stacklevel=2) + warnings.warn(f"Failed to patch transformers HybridCache compatibility: {e}", stacklevel=2) # Apply vLLM patches