Fix dtype mismatch in fp16 + 4-bit/8-bit LoRA training#4005
Conversation
Two fixes for training with dtype=torch.float16 and load_in_4bit=True: 1. fast_lora.py: fast_dequantize() returns tensors in quant_state.dtype (typically bfloat16 or float32), but activations may be float16. The subsequent matmul/addmm operations require matching dtypes. Add dtype casts after each fast_dequantize() call in LoRA_MLP.backward and LoRA_QKV.backward (5 locations total). 2. rl.py: TRL unconditionally casts trainable parameters to bfloat16 in the peft init block. When training with fp16=True, this causes GradScaler to crash since it requires float32 parameters. Make the cast conditional -- use float32 when fp16 is enabled, bfloat16 otherwise. This is a no-op for GRPOTrainer (whose peft init block is already removed by the existing regex), but fixes SFTTrainer and other TRL trainers. Tested with Llama-3.2-1B-Instruct 4-bit on both fp16 and bf16 training.
for more information, see https://pre-commit.ci
Summary of ChangesHello @danielhanchen, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request resolves two critical data type (dtype) related issues that impacted the stability and correctness of LoRA training, especially when using mixed-precision ( Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request effectively resolves two data type mismatch issues that occur during mixed-precision training with 4/8-bit quantization. The fix in fast_lora.py correctly handles dequantized weight dtypes in the backward pass, and the change in rl.py makes parameter casting conditional to prevent crashes with GradScaler. The changes are clear and well-targeted. I have one suggestion to improve the robustness of the monkey-patching logic in rl.py.
| RLTrainer_source = RLTrainer_source.replace( | ||
| "param.data = param.data.to(torch.bfloat16)", | ||
| "param.data = param.data.to(torch.float32 if getattr(args, 'fp16', False) else torch.bfloat16)", | ||
| ) |
There was a problem hiding this comment.
Using str.replace for monkey-patching can be brittle. It will fail silently if the source string in TRL changes even slightly (e.g., extra whitespace). Using re.sub with a pattern that allows for flexible whitespace would make this patch more robust against future changes in the TRL library.
| RLTrainer_source = RLTrainer_source.replace( | |
| "param.data = param.data.to(torch.bfloat16)", | |
| "param.data = param.data.to(torch.float32 if getattr(args, 'fp16', False) else torch.bfloat16)", | |
| ) | |
| RLTrainer_source = re.sub( | |
| r"param\.data\s*=\s*param\.data\.to\(torch\.bfloat16\)", | |
| "param.data = param.data.to(torch.float32 if getattr(args, 'fp16', False) else torch.bfloat16)", | |
| RLTrainer_source, | |
| ) |
Root cause: fast_dequantize returns tensors in quant_state.dtype, which for pre-quantized models is bfloat16 (from config.json). The post_patch methods in llama/gemma/gemma2 call patch_model_and_tokenizer without passing correct_dtype, so quant_state.dtype is never overridden to match the user's requested dtype. This causes a dtype mismatch crash in the backward pass when training with dtype=torch.float16. Fix: pass the user's dtype from from_pretrained through post_patch to patch_model_and_tokenizer as correct_dtype, matching the pattern already used by vision.py. Revert the 5 symptom-level dtype casts in fast_lora.py (upW, gateW, QW, KW, VW) since they are no longer needed with quant_state.dtype properly set at the source. Tested: fp16+4bit and bf16+4bit Llama-3.2-1B-Instruct 15-step SFT runs both complete successfully with similar losses (~1.558 vs ~1.563).
for more information, see https://pre-commit.ci
TRL 0.26.0+ hardcodes `param.data.to(torch.bfloat16)` for all trainable params in quantized models, citing the QLoRA paper recommendation. This is wrong: it ignores the user's requested dtype and breaks GradScaler when fp16=True. The block exists in sft_trainer, grpo_trainer, rloo_trainer, and reward_trainer (not dpo_trainer). Previous fix patched the cast to be dtype-conditional. This commit replaces the entire guard `if getattr(model, "is_loaded_in_4bit", ...) or getattr(model, "is_loaded_in_8bit", ...):` with `if False:` to disable the block entirely. Unsloth already handles adapter dtype via patch_model_and_tokenizer, making TRL's cast both unnecessary and harmful. For GRPOTrainer the enclosing peft init block is already removed by the regex above, making this a no-op for GRPO.
* Fix dtype mismatch in fp16 + 4-bit/8-bit LoRA training Two fixes for training with dtype=torch.float16 and load_in_4bit=True: 1. fast_lora.py: fast_dequantize() returns tensors in quant_state.dtype (typically bfloat16 or float32), but activations may be float16. The subsequent matmul/addmm operations require matching dtypes. Add dtype casts after each fast_dequantize() call in LoRA_MLP.backward and LoRA_QKV.backward (5 locations total). 2. rl.py: TRL unconditionally casts trainable parameters to bfloat16 in the peft init block. When training with fp16=True, this causes GradScaler to crash since it requires float32 parameters. Make the cast conditional -- use float32 when fp16 is enabled, bfloat16 otherwise. This is a no-op for GRPOTrainer (whose peft init block is already removed by the existing regex), but fixes SFTTrainer and other TRL trainers. Tested with Llama-3.2-1B-Instruct 4-bit on both fp16 and bf16 training. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix fp16 + 4-bit LoRA: thread correct_dtype through post_patch Root cause: fast_dequantize returns tensors in quant_state.dtype, which for pre-quantized models is bfloat16 (from config.json). The post_patch methods in llama/gemma/gemma2 call patch_model_and_tokenizer without passing correct_dtype, so quant_state.dtype is never overridden to match the user's requested dtype. This causes a dtype mismatch crash in the backward pass when training with dtype=torch.float16. Fix: pass the user's dtype from from_pretrained through post_patch to patch_model_and_tokenizer as correct_dtype, matching the pattern already used by vision.py. Revert the 5 symptom-level dtype casts in fast_lora.py (upW, gateW, QW, KW, VW) since they are no longer needed with quant_state.dtype properly set at the source. Tested: fp16+4bit and bf16+4bit Llama-3.2-1B-Instruct 15-step SFT runs both complete successfully with similar losses (~1.558 vs ~1.563). * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove TRL's unconditional bfloat16 cast instead of patching the dtype TRL 0.26.0+ hardcodes `param.data.to(torch.bfloat16)` for all trainable params in quantized models, citing the QLoRA paper recommendation. This is wrong: it ignores the user's requested dtype and breaks GradScaler when fp16=True. The block exists in sft_trainer, grpo_trainer, rloo_trainer, and reward_trainer (not dpo_trainer). Previous fix patched the cast to be dtype-conditional. This commit replaces the entire guard `if getattr(model, "is_loaded_in_4bit", ...) or getattr(model, "is_loaded_in_8bit", ...):` with `if False:` to disable the block entirely. Unsloth already handles adapter dtype via patch_model_and_tokenizer, making TRL's cast both unnecessary and harmful. For GRPOTrainer the enclosing peft init block is already removed by the regex above, making this a no-op for GRPO. --------- Co-authored-by: Daniel Hanchen <danielhanchen@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Summary
fast_lora.pywhen training withdtype=torch.float16andload_in_4bit=True.fast_dequantize()returns tensors inquant_state.dtype(typically bfloat16 or float32), but activations are float16. The subsequenttorch.matmul(..., out=X)andaddmm_()calls require matching dtypes. Added dtype casts after eachfast_dequantize()call inLoRA_MLP.backwardandLoRA_QKV.backward(5 locations).fp16=True, GradScaler requires float32 parameters. Made the cast conditional inrl.py-- use float32 when fp16 is enabled, bfloat16 otherwise. This is a no-op for GRPOTrainer (whose peft init block is already removed by the existing regex).Details
Bug 1: fast_lora.py backward pass
The error manifests as
RuntimeError: expected scalar type c10::Half but found c10::BFloat16during backward pass. This happens becausefast_dequantize()returns tensors in the model's original weight dtype fromquant_state, which can differ from the compute dtype. For instance, when a bfloat16 model is loaded withdtype=torch.float16, the quant_state retains bfloat16 but activations are float16.The fix adds
if W.dtype != dtype: W = W.to(dtype)after eachfast_dequantize()call, matching the pattern already used for LoRA weights (lines 138-145). This is a no-op when dtypes already match (the common bfloat16 case), so there is no performance impact for existing users.Bug 2: rl.py trainable param dtype cast
TRL's trainer
__init__has a block that casts all trainable params to bfloat16 when the model is loaded in 4-bit or 8-bit. This crashes GradScaler whenfp16=True:The fix replaces the unconditional
param.data = param.data.to(torch.bfloat16)with a conditional that uses float32 whenargs.fp16is set (which is what GradScaler needs). Theargsvariable is available at that point in the compiled source since Unsloth's wrapper sets training args beforesuper().__init__()runs.Companion PR: unslothai/unsloth-zoo#478
Test plan