From a62e4c6343fd3dbb86b38163e09837b86897ba47 Mon Sep 17 00:00:00 2001 From: Roland Tannous Date: Mon, 24 Mar 2025 15:39:00 +0000 Subject: [PATCH] Fix gradient checkpointing warning filter implementation --- unsloth_zoo/compiler.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 9717a96f4..3388c9c05 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1000,7 +1000,7 @@ def apply_fused_lm_head(forward): cross_entropy_replacement = cross_entropy_replacement\ .replace( - "$KWARGS$", + "$KWARGS$", "locals().get('loss_kwargs', {}) or locals().get('kwargs', {})" ) @@ -1179,7 +1179,7 @@ def patch_gradient_checkpointing(module, source): .replace("LAYER", layer).replace("MODULELIST_ITEM", modulelist_item)\ .replace("ARGS", args).replace("$", spaces) forward = forward.replace(forward[span[0] : span[1]], replacer) - + # Also fix init spaces = init.find("def") init = init + "\n" + (spaces + 4) * " " + "self.gradient_checkpointing = False\n\n" @@ -1381,10 +1381,10 @@ def patch_gradient_accumulation(modeling_file, module): functions = dir(modeling_file) module = eval(f"modeling_file.{module}") - try: + try: forward = module.forward source = inspect.getsource(forward) - except: + except: return None has_kwargs = tuple(inspect.signature(forward).parameters.values())[-1].kind == inspect._VAR_KEYWORD if has_kwargs: return None @@ -1450,6 +1450,10 @@ def unsloth_compile_transformers( disable : bool = False, return_logits : bool = False, ): + # import transformers logging module and instantiate model_type logging instance. 
+ from transformers import logging as transformers_logging + model_logger = transformers_logging.get_logger(f"modeling_{model_type}") + # All Unsloth Zoo code licensed under LGPLv3 disable = disable or (os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1") if fast_residual_stream: @@ -1461,8 +1465,9 @@ def unsloth_compile_transformers( modeling_file = eval(model_location) if hasattr(modeling_file, "__UNSLOTH_PATCHED__"): return - # Remove `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False` - exec("modeling_file.logger.addFilter(HideLoggingMessage('Setting `use_cache=False`'))", globals(), locals()) + # Use transformers model_type logger to suppress message: Remove `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False` + exec("model_logger.addFilter(HideLoggingMessage('Setting `use_cache=False`'))", globals(), locals()) + # torch_compile_options UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1" @@ -1792,7 +1797,7 @@ def unsloth_compile_transformers( # Disable if torch < 2.5 or V100s 7.0 (Tesla T4 7.5 works) or old Triton < 3 if OLD_CUDA_ARCH_VERSION or OLD_TORCH_VERSION or OLD_TRITON_VERSION: continue - + module_class = eval(f"modeling_file.{module}") if hasattr(module_class, "forward") and issubclass(module_class, GenerationMixin): try: