Skip to content

Fix Gemma3 QAT training instability with int8-int4 scheme#3818

Merged
danielhanchen merged 1 commit into
mainfrom
fix-gemma3-qat-stability
Jan 2, 2026
Merged

Fix Gemma3 QAT training instability with int8-int4 scheme#3818
danielhanchen merged 1 commit into
mainfrom
fix-gemma3-qat-stability

Conversation

@danielhanchen
Copy link
Copy Markdown
Member

Summary

Gemma3 models have a large vocabulary (262144 tokens) which causes training loss to explode (~17 instead of ~2.7) when using int8 embedding quantization with the int8-int4 (phone-deployment) QAT scheme.

This fix auto-detects Gemma3 models using get_transformers_model_type and switches to int4 weight-only QAT for stable training.

Changes

  • Added Gemma3 detection in _prepare_model_for_qat function
  • Auto-switches from int8-int4 to int4 scheme for Gemma3 models
  • Prints informative message when switching

Testing

Verified that training loss converges properly (~2.7) with the fix applied.

Gemma3 models have a large vocabulary (262144 tokens) which causes
training loss to explode when using int8 embedding quantization.

This fix auto-detects Gemma3 models and switches from int8-int4
(phone-deployment) to int4 weight-only QAT for stable training.
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @danielhanchen, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a significant training instability issue specific to Gemma3 models when performing quantization-aware training with the "int8-int4" scheme. The change introduces an automatic detection mechanism for Gemma3 models and dynamically adjusts the quantization strategy to "int4" weight-only, thereby ensuring stable and convergent training without the previously observed loss explosion.

Highlights

  • Gemma3 QAT Instability Fix: Addresses training instability in Gemma3 models when using the "int8-int4" QAT scheme, which was causing loss explosion due to their large vocabulary.
  • Automatic Scheme Switch: Implements automatic detection of Gemma3 models and switches the QAT scheme from "int8-int4" to "int4" weight-only to ensure stable training.
  • Informative Messaging: Adds a print statement to notify users when the QAT scheme is automatically adjusted for Gemma3 models.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a fix for training instability with Gemma3 models when using the int8-int4 QAT scheme. The change automatically detects Gemma3 models and switches to the more stable int4 weight-only QAT scheme, printing an informative message.

The implementation is clear and directly addresses the issue. I have a couple of suggestions to improve robustness and logging consistency:

  1. Make the model type check case-insensitive to handle potential variations in model names.
  2. Use logger.info instead of print for the notification message to align with the project's logging standards.

Overall, this is a good, targeted fix.

Comment thread unsloth/models/_utils.py
# large vocabulary size (262144). Auto-switch to int4 weight-only instead.
if qat_scheme == "int8-int4":
model_types = get_transformers_model_type(model.config)
is_gemma3 = any("gemma3" in mt or "gemma_3" in mt for mt in model_types)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For improved robustness, it's a good practice to make the model type check case-insensitive by converting mt to lowercase before the comparison. This will ensure that variations like 'Gemma3' or 'Gemma_3' are also correctly detected.

Suggested change
is_gemma3 = any("gemma3" in mt or "gemma_3" in mt for mt in model_types)
is_gemma3 = any("gemma3" in mt.lower() or "gemma_3" in mt.lower() for mt in model_types)

Comment thread unsloth/models/_utils.py
Comment on lines +2207 to +2210
print(
"Unsloth: Gemma3 has a large vocabulary causing int8 embedding issues. "
"Switching to int4 weight-only QAT for training stability."
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with the logging practices in the codebase, it's better to use logger.info instead of print for this informative message. This allows users to control log verbosity and redirect output if needed. logger is already imported and used in this file.

Suggested change
print(
"Unsloth: Gemma3 has a large vocabulary causing int8 embedding issues. "
"Switching to int4 weight-only QAT for training stability."
)
logger.info(
"Unsloth: Gemma3 has a large vocabulary causing int8 embedding issues. "
"Switching to int4 weight-only QAT for training stability."
)

@danielhanchen danielhanchen merged commit 62907d4 into main Jan 2, 2026
4 checks passed
abiswas-realadvice pushed a commit to abiswas-realadvice/unsloth that referenced this pull request May 14, 2026
…lity

Fix Gemma3 QAT training instability with int8-int4 scheme
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant