Skip to content

Fix gradient checkpointing warning filter implementation#97

Merged
shimmyshimmer merged 1 commit into
unslothai:nightlyfrom
rolandtannous:fix/suppress-gradient-checkpointing-warning
Mar 25, 2025
Merged

Fix gradient checkpointing warning filter implementation#97
shimmyshimmer merged 1 commit into
unslothai:nightlyfrom
rolandtannous:fix/suppress-gradient-checkpointing-warning

Conversation

@rolandtannous
Copy link
Copy Markdown
Contributor

Fix Logger Filter Implementation for Gradient Checkpointing Warnings

Description

This PR fixes a bug in the unsloth_compile_transformers method in compiler.py that causes an AttributeError when trying to suppress gradient checkpointing warnings. The current implementation incorrectly assumes that the model file has a logger attribute, but logger instances are typically module-level variables, not attributes.

Changes

  • Replaced the problematic exec() call that was looking for modeling_file.logger

  • Implemented a proper logging filter that gets the correct logger instance via the transformers logging system

  • Ensures the filter works across different model architectures (Gemma, Mistral, etc.)

Bug Details

The current code attempts to suppress warnings with:

exec("modeling_file.logger.addFilter(HideLoggingMessage('Setting `use_cache=False`'))", globals(), locals())

This fails with

AttributeError: module 'transformers.models.mistral3.modeling_mistral3' has no attribute 'logger'` because the logger is a module-level variable, not a model attribute.

Solution

The fix gets the appropriate logger directly from the transformers logging module and applies a filter to target only the specific gradient checkpointing warning message.

Testing

Verified the solution works with:

  • Gemma 3 models

  • Mistral 3 models

Related Issues

Fixes:

  • unsloth-zoo issue #90
  • unsloth issue #2146

@rolandtannous
Copy link
Copy Markdown
Contributor Author

rolandtannous commented Mar 25, 2025

This is a bug , resulting in a runtime error and actually preventing most users from even loading models using FastLanguageModel.from_pretrained. Example:

model, tokenizer = FastModel.from_pretrained(
    model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
    max_seq_length = 2048, # Choose any for long context!
    load_in_4bit = False,  # 4 bit quantization to reduce memory
    load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
    full_finetuning = False, # [NEW!] We have full finetuning now!
    # token = "hf_...", # use one if using gated models
)

results in the runtime error:

AttributeError: module 'transformers.models.mistral3.modeling_mistral3' has no attribute 'logger'

and prevents most people from progressing forward.

Same thing happens with other models like Gemma3.

SIde note: Oddly enough Transformers implementation of modeling_mistral3 isn't even importing the logging module at all to begin with.. This is not the reason for the runtime error though.

The underlying code in compiler.py throwing the runtime error isn't even trivial. It is just trying to suppress a category of logs by filtering it out to un-clutter the console.

@shimmyshimmer shimmyshimmer changed the base branch from main to nightly March 25, 2025 09:57
@shimmyshimmer shimmyshimmer merged commit 454757c into unslothai:nightly Mar 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants