Fix dtype mismatch in fp16 + 4-bit/8-bit LoRA training #4005

Merged: danielhanchen merged 5 commits into main from fix/fp16-dtype-mismatch on Feb 9, 2026

@danielhanchen (Member) commented on Feb 9, 2026

Summary

  • Fix backward pass dtype mismatch in fast_lora.py when training with dtype=torch.float16 and load_in_4bit=True. fast_dequantize() returns tensors in quant_state.dtype (typically bfloat16 or float32), but activations are float16. The subsequent torch.matmul(..., out=X) and addmm_() calls require matching dtypes. Added dtype casts after each fast_dequantize() call in LoRA_MLP.backward and LoRA_QKV.backward (5 locations).
  • Fix GradScaler crash when TRL unconditionally casts trainable parameters to bfloat16 in the peft init block. When fp16=True, GradScaler requires float32 parameters. Made the cast conditional in rl.py -- use float32 when fp16 is enabled, bfloat16 otherwise. This is a no-op for GRPOTrainer (whose peft init block is already removed by the existing regex).

Details

Bug 1: fast_lora.py backward pass

The error manifests as RuntimeError: expected scalar type c10::Half but found c10::BFloat16 during backward pass. This happens because fast_dequantize() returns tensors in the model's original weight dtype from quant_state, which can differ from the compute dtype. For instance, when a bfloat16 model is loaded with dtype=torch.float16, the quant_state retains bfloat16 but activations are float16.

The fix adds if W.dtype != dtype: W = W.to(dtype) after each fast_dequantize() call, matching the pattern already used for LoRA weights (lines 138-145). This is a no-op when dtypes already match (the common bfloat16 case), so there is no performance impact for existing users.
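
The conditional cast described above can be sketched in isolation. This is a minimal illustration, not the code in fast_lora.py: `match_dtype` is a hypothetical helper, and the tensors are CPU stand-ins that pair bfloat16 activations with a float32 weight (instead of float16 with bfloat16) only so the example runs without a GPU.

```python
import torch


def match_dtype(W: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical helper: cast W to `dtype` only when it differs."""
    if W.dtype != dtype:
        W = W.to(dtype)
    return W


# Stand-ins: X plays the role of the activations, W of a dequantized weight
# whose dtype came from the quantization state rather than the compute dtype.
X = torch.randn(2, 4, dtype=torch.bfloat16)
W = torch.randn(4, 3, dtype=torch.float32)

# matmul does not promote mixed dtypes, so this call is expected to fail.
try:
    torch.matmul(X, W)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# After the cast both operands share the compute dtype and matmul succeeds.
W_cast = match_dtype(W, X.dtype)
out = torch.matmul(X, W_cast)

# When the dtypes already match, the helper returns the same tensor object,
# which is why the cast costs nothing in the common case.
same_object = match_dtype(W_cast, X.dtype) is W_cast
```

The explicit `W.dtype != dtype` check mirrors the pattern quoted in the description and makes the no-op path visible in the code.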

Bug 2: rl.py trainable param dtype cast

TRL's trainer __init__ has a block that casts all trainable params to bfloat16 when the model is loaded in 4-bit or 8-bit. This crashes GradScaler when fp16=True:

_amp_foreach_non_finite_check_and_unscale_cuda not implemented for BFloat16

The fix replaces the unconditional param.data = param.data.to(torch.bfloat16) with a conditional that uses float32 when args.fp16 is set (which is what GradScaler needs). The args variable is available at that point in the compiled source since Unsloth's wrapper sets training args before super().__init__() runs.

Companion PR: unslothai/unsloth-zoo#478

Test plan

  • fp16 + 4-bit training with Llama-3.2-1B-Instruct (15 steps) -- loss 1.5313
  • bf16 + 4-bit training with Llama-3.2-1B-Instruct (15 steps) -- loss 1.5377 (regression check, losses match pre-fix baseline)
  • Gemma3 vision model fp16 training (requires companion unsloth-zoo PR)

danielhanchen and others added 2 commits February 9, 2026 14:44
Two fixes for training with dtype=torch.float16 and load_in_4bit=True:

1. fast_lora.py: fast_dequantize() returns tensors in quant_state.dtype
   (typically bfloat16 or float32), but activations may be float16. The
   subsequent matmul/addmm operations require matching dtypes. Add dtype
   casts after each fast_dequantize() call in LoRA_MLP.backward and
   LoRA_QKV.backward (5 locations total).

2. rl.py: TRL unconditionally casts trainable parameters to bfloat16 in
   the peft init block. When training with fp16=True, this causes
   GradScaler to crash since it requires float32 parameters. Make the
   cast conditional -- use float32 when fp16 is enabled, bfloat16
   otherwise. This is a no-op for GRPOTrainer (whose peft init block is
   already removed by the existing regex), but fixes SFTTrainer and
   other TRL trainers.

Tested with Llama-3.2-1B-Instruct 4-bit on both fp16 and bf16 training.

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @danielhanchen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves two critical data type (dtype) related issues that impacted the stability and correctness of LoRA training, especially when using mixed-precision (fp16) and 4-bit/8-bit quantization. The first fix ensures that tensors returned by fast_dequantize() in the fast_lora.py backward pass are explicitly cast to the correct compute dtype, preventing RuntimeErrors during torch.matmul and addmm_ operations. The second fix modifies rl.py to prevent GradScaler crashes by conditionally casting trainable parameters to float32 when fp16 training is active, aligning with GradScaler's requirements. These changes enhance the robustness of the training pipeline for advanced quantization and mixed-precision setups.

Highlights

  • Fast LoRA Dtype Mismatch Fix: Corrected a backward pass dtype mismatch in fast_lora.py when training with torch.float16 and 4-bit/8-bit quantization. Weights returned by fast_dequantize() are now cast to the compute dtype when the two differ, resolving RuntimeError: expected scalar type c10::Half but found c10::BFloat16.
  • GradScaler Crash Prevention: Addressed a GradScaler crash in rl.py by making the casting of trainable parameters conditional. Parameters are now cast to torch.float32 when fp16 is enabled (as required by GradScaler) and torch.bfloat16 otherwise, instead of an unconditional bfloat16 cast.

Changelog
  • unsloth/kernels/fast_lora.py
    • Introduced conditional dtype casting (if W.dtype != dtype: W = W.to(dtype)) for upW, gateW in LoRA_MLP.backward and for QW, KW, VW in LoRA_QKV.backward after fast_dequantize calls to ensure dtype consistency with the compute dtype.
  • unsloth/models/rl.py
    • Modified the _patch_trl_rl_trainers function to replace a hardcoded param.data.to(torch.bfloat16) with a conditional expression: param.data.to(torch.float32 if getattr(args, 'fp16', False) else torch.bfloat16). This dynamically sets the parameter dtype based on the fp16 training flag.

Activity
  • The author has identified and provided detailed explanations for two distinct bugs related to dtype mismatches in mixed-precision and quantized LoRA training.
  • A test plan has been provided, confirming successful fp16 + 4-bit and bf16 + 4-bit training with Llama-3.2-1B-Instruct.
  • A future test for Gemma3 vision model fp16 training is planned, contingent on a related unsloth-zoo PR.

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request effectively resolves two data type mismatch issues that occur during mixed-precision training with 4/8-bit quantization. The fix in fast_lora.py correctly handles dequantized weight dtypes in the backward pass, and the change in rl.py makes parameter casting conditional to prevent crashes with GradScaler. The changes are clear and well-targeted. I have one suggestion to improve the robustness of the monkey-patching logic in rl.py.

Comment thread: unsloth/models/rl.py
Comment on lines +1253 to +1256:

    RLTrainer_source = RLTrainer_source.replace(
        "param.data = param.data.to(torch.bfloat16)",
        "param.data = param.data.to(torch.float32 if getattr(args, 'fp16', False) else torch.bfloat16)",
    )

medium

Using str.replace for monkey-patching can be brittle. It will fail silently if the source string in TRL changes even slightly (e.g., extra whitespace). Using re.sub with a pattern that allows for flexible whitespace would make this patch more robust against future changes in the TRL library.

Suggested change

    - RLTrainer_source = RLTrainer_source.replace(
    -     "param.data = param.data.to(torch.bfloat16)",
    -     "param.data = param.data.to(torch.float32 if getattr(args, 'fp16', False) else torch.bfloat16)",
    - )
    + RLTrainer_source = re.sub(
    +     r"param\.data\s*=\s*param\.data\.to\(torch\.bfloat16\)",
    +     "param.data = param.data.to(torch.float32 if getattr(args, 'fp16', False) else torch.bfloat16)",
    +     RLTrainer_source,
    + )
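
The difference between the two approaches can be seen on a small stand-in string. The sketch below is illustrative only: `upstream` and `reformatted` are made-up source fragments, not excerpts from TRL, and `patch_exact` / `patch_regex` are hypothetical helper names. It also uses `re.subn` so the caller can tell whether anything was replaced, which addresses the "fails silently" concern for either approach.

```python
import re

OLD = "param.data = param.data.to(torch.bfloat16)"
NEW = (
    "param.data = param.data.to("
    "torch.float32 if getattr(args, 'fp16', False) else torch.bfloat16)"
)
# Same pattern as in the suggestion: tolerate any whitespace around "=".
PATTERN = re.compile(r"param\.data\s*=\s*param\.data\.to\(torch\.bfloat16\)")


def patch_exact(source: str) -> str:
    """Literal replacement: only matches the exact original spelling."""
    return source.replace(OLD, NEW)


def patch_regex(source: str):
    """Regex replacement; returns (patched_source, number_of_replacements)."""
    # A callable replacement avoids backslash/group interpretation in NEW.
    return PATTERN.subn(lambda match: NEW, source)


upstream = "    param.data = param.data.to(torch.bfloat16)\n"
reformatted = "    param.data  =  param.data.to(torch.bfloat16)\n"  # assumed whitespace drift

exact_on_upstream = patch_exact(upstream)
exact_on_reformatted = patch_exact(reformatted)  # unchanged, with no error raised
regex_on_reformatted, count = patch_regex(reformatted)
```

Checking `count` (or comparing the string before and after) is one way to turn a silently skipped patch into a visible warning.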

danielhanchen and others added 3 commits February 9, 2026 15:16
Root cause: fast_dequantize returns tensors in quant_state.dtype, which
for pre-quantized models is bfloat16 (from config.json). The post_patch
methods in llama/gemma/gemma2 call patch_model_and_tokenizer without
passing correct_dtype, so quant_state.dtype is never overridden to match
the user's requested dtype. This causes a dtype mismatch crash in the
backward pass when training with dtype=torch.float16.

Fix: pass the user's dtype from from_pretrained through post_patch to
patch_model_and_tokenizer as correct_dtype, matching the pattern already
used by vision.py.
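
As a rough sketch of what threading the dtype through looks like, the toy model below uses plain strings in place of torch dtypes. Every name in it (`QuantState`, `QuantLinear`, `patch_model`, `load`, and the `post_patch` signature) is a hypothetical stand-in and does not reproduce Unsloth's actual functions.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class QuantState:
    dtype: str  # dtype recorded when the checkpoint was quantized


@dataclass
class QuantLinear:
    quant_state: QuantState


def patch_model(layers: List[QuantLinear],
                correct_dtype: Optional[str] = None) -> List[QuantLinear]:
    """Stand-in for the final patch step: override quant_state.dtype if asked."""
    if correct_dtype is not None:
        for layer in layers:
            layer.quant_state.dtype = correct_dtype
    return layers


def post_patch(layers: List[QuantLinear],
               correct_dtype: Optional[str] = None) -> List[QuantLinear]:
    """Stand-in for a per-architecture hook that must forward the dtype."""
    return patch_model(layers, correct_dtype=correct_dtype)


def dequantized_dtype(layer: QuantLinear) -> str:
    """Dtype a dequantize call would produce: whatever quant_state says."""
    return layer.quant_state.dtype


def load(requested_dtype: str, forward_dtype: bool) -> List[QuantLinear]:
    """Build two pre-quantized bfloat16 layers and run the patch chain."""
    layers = [QuantLinear(QuantState("bfloat16")) for _ in range(2)]
    return post_patch(layers, correct_dtype=requested_dtype if forward_dtype else None)


before = [dequantized_dtype(l) for l in load("float16", forward_dtype=False)]
after = [dequantized_dtype(l) for l in load("float16", forward_dtype=True)]
```

When the hook drops the argument, the layers keep reporting the checkpoint dtype; once it is forwarded, they report the requested one.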

Revert the 5 symptom-level dtype casts in fast_lora.py (upW, gateW, QW,
KW, VW) since they are no longer needed with quant_state.dtype properly
set at the source.

Tested: fp16+4bit and bf16+4bit Llama-3.2-1B-Instruct 15-step SFT runs
both complete successfully with similar losses (~1.558 vs ~1.563).

TRL 0.26.0+ hardcodes `param.data.to(torch.bfloat16)` for all trainable
params in quantized models, citing the QLoRA paper recommendation. This
is wrong: it ignores the user's requested dtype and breaks GradScaler
when fp16=True. The block exists in sft_trainer, grpo_trainer,
rloo_trainer, and reward_trainer (not dpo_trainer).

Previous fix patched the cast to be dtype-conditional. This commit
replaces the entire guard `if getattr(model, "is_loaded_in_4bit", ...)
or getattr(model, "is_loaded_in_8bit", ...):` with `if False:` to
disable the block entirely. Unsloth already handles adapter dtype via
patch_model_and_tokenizer, making TRL's cast both unnecessary and
harmful.

For GRPOTrainer the enclosing peft init block is already removed by
the regex above, making this a no-op for GRPO.
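
The guard replacement can be illustrated on a made-up source fragment. This is a sketch under assumptions, not the patch in rl.py: `SOURCE` is a hypothetical stand-in for the TRL block, the `False` defaults inside its `getattr` calls are assumed for illustration (the commit message elides them), and `GUARD` is one possible pattern that accepts any default value.

```python
import re

# Hypothetical stand-in for the block in a TRL trainer's __init__.
SOURCE = '''\
if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False):
    for param in model.parameters():
        if param.requires_grad:
            param.data = param.data.to(torch.bfloat16)
'''

# Match the whole guard line regardless of the getattr default values.
GUARD = re.compile(
    r'if getattr\(model, "is_loaded_in_4bit",[^)]*\)\s*or\s*'
    r'getattr\(model, "is_loaded_in_8bit",[^)]*\):'
)

patched, replacements = GUARD.subn("if False:", SOURCE)

# The body is kept but can never run, so executing the patched source touches
# neither `model.parameters()` nor `torch` (neither exists in this namespace
# beyond a bare object for `model`).
namespace = {"model": object()}
exec(compile(patched, "<patched>", "exec"), namespace)
body_is_dead = True  # reached only if exec raised nothing
```

Replacing only the condition leaves the indentation and the rest of the source intact, so the patched text stays valid Python.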

@danielhanchen merged commit 0960687 into main on Feb 9, 2026
4 checks passed
@danielhanchen deleted the fix/fp16-dtype-mismatch branch on Feb 9, 2026 15:39
abiswas-realadvice pushed a commit to abiswas-realadvice/unsloth that referenced this pull request May 14, 2026
* Fix dtype mismatch in fp16 + 4-bit/8-bit LoRA training

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix fp16 + 4-bit LoRA: thread correct_dtype through post_patch

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove TRL's unconditional bfloat16 cast instead of patching the dtype

---------

Co-authored-by: Daniel Hanchen <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>