diff --git a/trl/_compat.py b/trl/_compat.py index 6beeac97a50..80a1cd386d3 100644 --- a/trl/_compat.py +++ b/trl/_compat.py @@ -180,9 +180,35 @@ def _patch_transformers_hybrid_cache() -> None: _is_package_version_below("liger_kernel", "0.6.5") or _is_package_version_below("peft", "0.18.0") ): try: - import transformers.cache_utils as cache_utils + import transformers.cache_utils + from transformers.utils.import_utils import _LazyModule + + Cache = transformers.cache_utils.Cache + + # Patch for liger_kernel: Add HybridCache as an alias for Cache in the cache_utils module + transformers.cache_utils.HybridCache = Cache + + # Patch for peft: Patch _LazyModule.__init__ to add HybridCache to transformers' lazy loading structures + _original_lazy_module_init = _LazyModule.__init__ + + def _patched_lazy_module_init(self, name, *args, **kwargs): + _original_lazy_module_init(self, name, *args, **kwargs) + if name == "transformers": + # Update _LazyModule's internal structures + if hasattr(self, "_import_structure") and "cache_utils" in self._import_structure: + if "HybridCache" not in self._import_structure["cache_utils"]: + self._import_structure["cache_utils"].append("HybridCache") + + if hasattr(self, "_class_to_module"): + self._class_to_module["HybridCache"] = "cache_utils" + + if hasattr(self, "__all__") and "HybridCache" not in self.__all__: + self.__all__.append("HybridCache") + + self.HybridCache = Cache + + _LazyModule.__init__ = _patched_lazy_module_init - cache_utils.HybridCache = cache_utils.Cache except Exception as e: warnings.warn(f"Failed to patch transformers HybridCache compatibility: {e}", stacklevel=2)