Merged
12 changes: 12 additions & 0 deletions unsloth/models/_utils.py
@@ -2198,6 +2198,18 @@ def _prepare_model_for_qat(
    from torchao.quantization.granularity import PerGroup, PerAxis
    from torchao.quantization.qat import QATConfig

    # Gemma3 models have issues with int8 embedding quantization due to their
    # large vocabulary size (262144). Auto-switch to int4 weight-only instead.
    if qat_scheme == "int8-int4":
        model_types = get_transformers_model_type(model.config)
        is_gemma3 = any("gemma3" in mt or "gemma_3" in mt for mt in model_types)
Contributor review comment (medium):

For robustness, make the model type check case-insensitive by lowercasing mt before the comparison, so that variants such as 'Gemma3' or 'Gemma_3' are also detected.

Suggested change
-        is_gemma3 = any("gemma3" in mt or "gemma_3" in mt for mt in model_types)
+        is_gemma3 = any("gemma3" in mt.lower() or "gemma_3" in mt.lower() for mt in model_types)
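The case-insensitive check suggested above can be tried in isolation. The following is a minimal sketch, not the code merged in this PR: the helper name `is_gemma3_model_type` is hypothetical, and it assumes the model types arrive as an iterable of plain strings.

```python
from typing import Iterable


def is_gemma3_model_type(model_types: Iterable[str]) -> bool:
    """Hypothetical helper: True if any model type names a Gemma3 variant.

    Each string is lowercased first, so 'Gemma3' and 'GEMMA_3' match as well.
    """
    return any(
        "gemma3" in mt.lower() or "gemma_3" in mt.lower()
        for mt in model_types
    )
```

Because this is a substring test, any longer model type string that contains "gemma3" or "gemma_3" also matches; whether that is desirable depends on the model families involved.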

        if is_gemma3:
            print(
                "Unsloth: Gemma3 has a large vocabulary causing int8 embedding issues. "
                "Switching to int4 weight-only QAT for training stability."
            )
Comment on lines +2207 to +2210
Contributor review comment (medium):

For consistency with logging elsewhere in the codebase, use logger.info instead of print for this informational message. This lets users control log verbosity and redirect output. logger is already imported and used in this file.

Suggested change
-            print(
-                "Unsloth: Gemma3 has a large vocabulary causing int8 embedding issues. "
-                "Switching to int4 weight-only QAT for training stability."
-            )
+            logger.info(
+                "Unsloth: Gemma3 has a large vocabulary causing int8 embedding issues. "
+                "Switching to int4 weight-only QAT for training stability."
+            )
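To illustrate the verbosity point in the comment above, here is a small standalone sketch using only the standard `logging` module. The logger name `qat_example` and the in-memory stream are assumptions made for the demonstration; the project's own logger setup may differ.

```python
import io
import logging

# Hypothetical logger and in-memory stream, used only for this demonstration.
logger = logging.getLogger("qat_example")
stream = io.StringIO()
logger.addHandler(logging.StreamHandler(stream))
logger.propagate = False  # keep the demo output out of the root logger

MESSAGE = (
    "Unsloth: Gemma3 has a large vocabulary causing int8 embedding issues. "
    "Switching to int4 weight-only QAT for training stability."
)

# At INFO level the message is emitted.
logger.setLevel(logging.INFO)
logger.info(MESSAGE)

# At WARNING level the same call is suppressed, which print() cannot offer.
logger.setLevel(logging.WARNING)
logger.info(MESSAGE)

captured = stream.getvalue()
```

After running this, `captured` holds the message once: the second call is filtered out by the logger's level.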

            qat_scheme = "int4"

    if not isinstance(qat_scheme, TorchAOConfig):
        torchao_config: Optional[TorchAOConfig] = None
        if qat_scheme == "fp8-int4":
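The fallback added in this hunk can also be expressed as a pure function, which makes the behavior easy to check without loading a model. This is a sketch under stated assumptions: `resolve_qat_scheme` is a hypothetical name, the scheme is assumed to be a plain string, and the case-insensitive match from the review suggestion is used.

```python
from typing import Iterable


def resolve_qat_scheme(qat_scheme: str, model_types: Iterable[str]) -> str:
    """Hypothetical helper: swap "int8-int4" for "int4" on Gemma3 models.

    Every other scheme, and every non-Gemma3 model, is returned unchanged.
    """
    if qat_scheme != "int8-int4":
        return qat_scheme
    is_gemma3 = any(
        "gemma3" in mt.lower() or "gemma_3" in mt.lower()
        for mt in model_types
    )
    return "int4" if is_gemma3 else qat_scheme
```

Only the "int8-int4" scheme is rewritten, matching the guard in the diff; a scheme such as "fp8-int4" passes through untouched even for Gemma3.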