Fix Gemma3 QAT training instability with int8-int4 scheme#3818
Conversation
Gemma3 models have a large vocabulary (262144 tokens) which causes training loss to explode when using int8 embedding quantization. This fix auto-detects Gemma3 models and switches from int8-int4 (phone-deployment) to int4 weight-only QAT for stable training.
Summary of ChangesHello @danielhanchen, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request resolves a significant training instability issue specific to Gemma3 models when performing quantization-aware training with the "int8-int4" scheme. The change introduces an automatic detection mechanism for Gemma3 models and dynamically adjusts the quantization strategy to "int4" weight-only, thereby ensuring stable and convergent training without the previously observed loss explosion. Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request introduces a fix for training instability with Gemma3 models when using the int8-int4 QAT scheme. The change automatically detects Gemma3 models and switches to the more stable int4 weight-only QAT scheme, printing an informative message.
The implementation is clear and directly addresses the issue. I have a couple of suggestions to improve robustness and logging consistency:
- Make the model type check case-insensitive to handle potential variations in model names.
- Use
logger.infoinstead ofprintfor the notification message to align with the project's logging standards.
Overall, this is a good, targeted fix.
| # large vocabulary size (262144). Auto-switch to int4 weight-only instead. | ||
| if qat_scheme == "int8-int4": | ||
| model_types = get_transformers_model_type(model.config) | ||
| is_gemma3 = any("gemma3" in mt or "gemma_3" in mt for mt in model_types) |
There was a problem hiding this comment.
For improved robustness, it's a good practice to make the model type check case-insensitive by converting mt to lowercase before the comparison. This will ensure that variations like 'Gemma3' or 'Gemma_3' are also correctly detected.
| is_gemma3 = any("gemma3" in mt or "gemma_3" in mt for mt in model_types) | |
| is_gemma3 = any("gemma3" in mt.lower() or "gemma_3" in mt.lower() for mt in model_types) |
| print( | ||
| "Unsloth: Gemma3 has a large vocabulary causing int8 embedding issues. " | ||
| "Switching to int4 weight-only QAT for training stability." | ||
| ) |
There was a problem hiding this comment.
For consistency with the logging practices in the codebase, it's better to use logger.info instead of print for this informative message. This allows users to control log verbosity and redirect output if needed. logger is already imported and used in this file.
| print( | |
| "Unsloth: Gemma3 has a large vocabulary causing int8 embedding issues. " | |
| "Switching to int4 weight-only QAT for training stability." | |
| ) | |
| logger.info( | |
| "Unsloth: Gemma3 has a large vocabulary causing int8 embedding issues. " | |
| "Switching to int4 weight-only QAT for training stability." | |
| ) |
…lity Fix Gemma3 QAT training instability with int8-int4 scheme
Summary
Gemma3 models have a large vocabulary (262144 tokens) which causes training loss to explode (~17 instead of ~2.7) when using int8 embedding quantization with the
int8-int4(phone-deployment) QAT scheme.This fix auto-detects Gemma3 models using
get_transformers_model_typeand switches toint4weight-only QAT for stable training.Changes
_prepare_model_for_qatfunctionint8-int4toint4scheme for Gemma3 modelsTesting
Verified that training loss converges properly (~2.7) with the fix applied.