From df9dad1849b3e642db92505ee069aa3bfbc056a5 Mon Sep 17 00:00:00 2001 From: Daniel Hanchen Date: Mon, 9 Feb 2026 14:44:35 +0000 Subject: [PATCH 1/5] Fix dtype mismatch in fp16 + 4-bit/8-bit LoRA training Two fixes for training with dtype=torch.float16 and load_in_4bit=True: 1. fast_lora.py: fast_dequantize() returns tensors in quant_state.dtype (typically bfloat16 or float32), but activations may be float16. The subsequent matmul/addmm operations require matching dtypes. Add dtype casts after each fast_dequantize() call in LoRA_MLP.backward and LoRA_QKV.backward (5 locations total). 2. rl.py: TRL unconditionally casts trainable parameters to bfloat16 in the peft init block. When training with fp16=True, this causes GradScaler to crash since it requires float32 parameters. Make the cast conditional -- use float32 when fp16 is enabled, bfloat16 otherwise. This is a no-op for GRPOTrainer (whose peft init block is already removed by the existing regex), but fixes SFTTrainer and other TRL trainers. Tested with Llama-3.2-1B-Instruct 4-bit on both fp16 and bf16 training. 
--- unsloth/kernels/fast_lora.py | 5 +++++ unsloth/models/rl.py | 11 +++++++++++ 2 files changed, 16 insertions(+) diff --git a/unsloth/kernels/fast_lora.py b/unsloth/kernels/fast_lora.py index f1c0e298d9..1292cc03a5 100644 --- a/unsloth/kernels/fast_lora.py +++ b/unsloth/kernels/fast_lora.py @@ -191,12 +191,14 @@ def backward(ctx, dY: torch.Tensor): # dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS) # dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS) upW = fast_dequantize(upW.t(), upW_quant) + if upW.dtype != dtype: upW = upW.to(dtype) dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None) del upW # dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t()) dX.addmm_(df @ upB.t(), upA.t(), alpha = upS) gateW = fast_dequantize(gateW.t(), gateW_quant) + if gateW.dtype != dtype: gateW = gateW.to(dtype) # dX += de @ gateW.t() dX.addmm_(de, gateW.t()) del gateW @@ -487,6 +489,7 @@ def backward(ctx, dQ, dK, dV): # Combine derivatives to find dX # dQ QW = fast_dequantize(QW.t(), QW_quant) + if QW.dtype != dtype: QW = QW.to(dtype) dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None) del QW # dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t())) @@ -494,6 +497,7 @@ def backward(ctx, dQ, dK, dV): # dK KW = fast_dequantize(KW.t(), KW_quant) + if KW.dtype != dtype: KW = KW.to(dtype) # dX += dK @ KW.t() dX.addmm_(dK, KW.t()) del KW @@ -502,6 +506,7 @@ def backward(ctx, dQ, dK, dV): # dV VW = fast_dequantize(VW.t(), VW_quant) + if VW.dtype != dtype: VW = VW.to(dtype) # dX += dV @ VW.t() dX.addmm_(dV, VW.t()) del VW diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 91f071b1ea..9a86db17be 100755 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -1244,6 +1244,17 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): flags = re.DOTALL, ) + # Fix fp16 + 4-bit/8-bit training: TRL unconditionally casts trainable params + # to bfloat16, which crashes GradScaler when training in float16. 
Make the + # cast respect the training dtype: float32 for fp16, bfloat16 otherwise. + # For GRPOTrainer the entire peft init block is already removed above, so + # this replace is a no-op for GRPO. For SFT and other trainers it makes + # the cast conditional. + RLTrainer_source = RLTrainer_source.replace( + "param.data = param.data.to(torch.bfloat16)", + "param.data = param.data.to(torch.float32 if getattr(args, 'fp16', False) else torch.bfloat16)", + ) + if RLTrainer_name == "SFTTrainer": original_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask"]' new_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask","labels"]' From 5aadfcdfeefc630a501eb0cb7363830f771c8a2c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Feb 2026 14:45:37 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/kernels/fast_lora.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/unsloth/kernels/fast_lora.py b/unsloth/kernels/fast_lora.py index 1292cc03a5..47776c30a8 100644 --- a/unsloth/kernels/fast_lora.py +++ b/unsloth/kernels/fast_lora.py @@ -191,14 +191,16 @@ def backward(ctx, dY: torch.Tensor): # dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS) # dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS) upW = fast_dequantize(upW.t(), upW_quant) - if upW.dtype != dtype: upW = upW.to(dtype) + if upW.dtype != dtype: + upW = upW.to(dtype) dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None) del upW # dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t()) dX.addmm_(df @ upB.t(), upA.t(), alpha = upS) gateW = fast_dequantize(gateW.t(), gateW_quant) - if gateW.dtype != dtype: gateW = gateW.to(dtype) + if gateW.dtype != dtype: + gateW = gateW.to(dtype) # dX += de @ gateW.t() dX.addmm_(de, gateW.t()) del gateW @@ 
-489,7 +491,8 @@ def backward(ctx, dQ, dK, dV): # Combine derivatives to find dX # dQ QW = fast_dequantize(QW.t(), QW_quant) - if QW.dtype != dtype: QW = QW.to(dtype) + if QW.dtype != dtype: + QW = QW.to(dtype) dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None) del QW # dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t())) @@ -497,7 +500,8 @@ def backward(ctx, dQ, dK, dV): # dK KW = fast_dequantize(KW.t(), KW_quant) - if KW.dtype != dtype: KW = KW.to(dtype) + if KW.dtype != dtype: + KW = KW.to(dtype) # dX += dK @ KW.t() dX.addmm_(dK, KW.t()) del KW @@ -506,7 +510,8 @@ def backward(ctx, dQ, dK, dV): # dV VW = fast_dequantize(VW.t(), VW_quant) - if VW.dtype != dtype: VW = VW.to(dtype) + if VW.dtype != dtype: + VW = VW.to(dtype) # dX += dV @ VW.t() dX.addmm_(dV, VW.t()) del VW From 60da2b59080259759f8c7c8c5803f0d882f43015 Mon Sep 17 00:00:00 2001 From: Daniel Hanchen Date: Mon, 9 Feb 2026 15:16:05 +0000 Subject: [PATCH 3/5] Fix fp16 + 4-bit LoRA: thread correct_dtype through post_patch Root cause: fast_dequantize returns tensors in quant_state.dtype, which for pre-quantized models is bfloat16 (from config.json). The post_patch methods in llama/gemma/gemma2 call patch_model_and_tokenizer without passing correct_dtype, so quant_state.dtype is never overridden to match the user's requested dtype. This causes a dtype mismatch crash in the backward pass when training with dtype=torch.float16. Fix: pass the user's dtype from from_pretrained through post_patch to patch_model_and_tokenizer as correct_dtype, matching the pattern already used by vision.py. In granite.py, post_patch only gains the correct_dtype parameter (default None) so that its signature matches the new call post_patch(model, tokenizer, correct_dtype = dtype); this patch does not forward the value there. Revert the 5 symptom-level dtype casts in fast_lora.py (upW, gateW, QW, KW, VW) since they are no longer needed with quant_state.dtype properly set at the source. Tested: fp16+4bit and bf16+4bit Llama-3.2-1B-Instruct 15-step SFT runs both complete successfully with similar losses (~1.558 vs ~1.563). 
--- unsloth/kernels/fast_lora.py | 10 ---------- unsloth/models/gemma.py | 4 ++-- unsloth/models/gemma2.py | 4 ++-- unsloth/models/granite.py | 2 +- unsloth/models/llama.py | 6 +++--- 5 files changed, 8 insertions(+), 18 deletions(-) diff --git a/unsloth/kernels/fast_lora.py b/unsloth/kernels/fast_lora.py index 47776c30a8..f1c0e298d9 100644 --- a/unsloth/kernels/fast_lora.py +++ b/unsloth/kernels/fast_lora.py @@ -191,16 +191,12 @@ def backward(ctx, dY: torch.Tensor): # dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS) # dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS) upW = fast_dequantize(upW.t(), upW_quant) - if upW.dtype != dtype: - upW = upW.to(dtype) dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None) del upW # dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t()) dX.addmm_(df @ upB.t(), upA.t(), alpha = upS) gateW = fast_dequantize(gateW.t(), gateW_quant) - if gateW.dtype != dtype: - gateW = gateW.to(dtype) # dX += de @ gateW.t() dX.addmm_(de, gateW.t()) del gateW @@ -491,8 +487,6 @@ def backward(ctx, dQ, dK, dV): # Combine derivatives to find dX # dQ QW = fast_dequantize(QW.t(), QW_quant) - if QW.dtype != dtype: - QW = QW.to(dtype) dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None) del QW # dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t())) @@ -500,8 +494,6 @@ def backward(ctx, dQ, dK, dV): # dK KW = fast_dequantize(KW.t(), KW_quant) - if KW.dtype != dtype: - KW = KW.to(dtype) # dX += dK @ KW.t() dX.addmm_(dK, KW.t()) del KW @@ -510,8 +502,6 @@ def backward(ctx, dQ, dK, dV): # dV VW = fast_dequantize(VW.t(), VW_quant) - if VW.dtype != dtype: - VW = VW.to(dtype) # dX += dV @ VW.t() dX.addmm_(dV, VW.t()) del VW diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 7173c03495..55a8c8697f 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -442,10 +442,10 @@ def pre_patch(): return @staticmethod - def post_patch(model, tokenizer): + def post_patch(model, tokenizer, 
correct_dtype = None): # Gemma does not downcast RoPE model, tokenizer = patch_model_and_tokenizer( - model, tokenizer, downcast_rope = False + model, tokenizer, downcast_rope = False, correct_dtype = correct_dtype ) # Add 1 to weight diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 16d04955d3..03e77f6504 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -613,10 +613,10 @@ def pre_patch(): return @staticmethod - def post_patch(model, tokenizer): + def post_patch(model, tokenizer, correct_dtype = None): # Gemma does not downcast RoPE model, tokenizer = patch_model_and_tokenizer( - model, tokenizer, downcast_rope = False + model, tokenizer, downcast_rope = False, correct_dtype = correct_dtype ) # Add 1 to weight diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index aae746aed1..168df90f4c 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -542,7 +542,7 @@ def pre_patch(): return @staticmethod - def post_patch(model, tokenizer): + def post_patch(model, tokenizer, correct_dtype = None): # Torch.compile fails on embedding matrix?? 
# Workaround randomnly fixes it for torch versions < 2.2 model.model.embed_tokens = torch.nn.Embedding.from_pretrained( diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 61771e4567..bf8a528e5f 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2483,7 +2483,7 @@ def from_pretrained( ) model, tokenizer = patch_tokenizer(model, tokenizer) - model, tokenizer = model_patcher.post_patch(model, tokenizer) + model, tokenizer = model_patcher.post_patch(model, tokenizer, correct_dtype = dtype) # Patch up QKV / O and MLP for idx, layer in enumerate(model.model.layers): @@ -2666,9 +2666,9 @@ def from_pretrained( return model, tokenizer @staticmethod - def post_patch(model, tokenizer): + def post_patch(model, tokenizer, correct_dtype = None): model, tokenizer = patch_model_and_tokenizer( - model, tokenizer, downcast_rope = True + model, tokenizer, downcast_rope = True, correct_dtype = correct_dtype ) return model, tokenizer From f0f4d40f4c820aeb77b55469599b96561b0c3de9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Feb 2026 15:17:14 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/models/llama.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index bf8a528e5f..c1e9110759 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2483,7 +2483,9 @@ def from_pretrained( ) model, tokenizer = patch_tokenizer(model, tokenizer) - model, tokenizer = model_patcher.post_patch(model, tokenizer, correct_dtype = dtype) + model, tokenizer = model_patcher.post_patch( + model, tokenizer, correct_dtype = dtype + ) # Patch up QKV / O and MLP for idx, layer in enumerate(model.model.layers): From 5f4f5aeed4d849cbfeec9671efae8206074adfa4 Mon Sep 17 00:00:00 2001 From: Daniel Hanchen Date: Mon, 9 Feb 2026 15:36:53 
+0000 Subject: [PATCH 5/5] Remove TRL's unconditional bfloat16 cast instead of patching the dtype TRL 0.26.0+ hardcodes `param.data.to(torch.bfloat16)` for all trainable params in quantized models, citing the QLoRA paper recommendation. This is wrong: it ignores the user's requested dtype and breaks GradScaler when fp16=True. The block exists in sft_trainer, grpo_trainer, rloo_trainer, and reward_trainer (not dpo_trainer). The previous fix (PATCH 1/5) made the cast dtype-conditional: float32 when fp16 is set, bfloat16 otherwise. This commit instead replaces the entire guard `if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False):` with `if False:` (an exact-string `str.replace` on the trainer source) to disable the block entirely. Unsloth already handles adapter dtype via patch_model_and_tokenizer, making TRL's cast both unnecessary and harmful. For GRPOTrainer the enclosing peft init block is already removed by the existing regex in _patch_trl_rl_trainers, making this a no-op for GRPO. --- unsloth/models/rl.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 9a86db17be..e22be55c7a 100755 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -1244,15 +1244,16 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): flags = re.DOTALL, ) - # Fix fp16 + 4-bit/8-bit training: TRL unconditionally casts trainable params - # to bfloat16, which crashes GradScaler when training in float16. Make the - # cast respect the training dtype: float32 for fp16, bfloat16 otherwise. - # For GRPOTrainer the entire peft init block is already removed above, so - # this replace is a no-op for GRPO. For SFT and other trainers it makes - # the cast conditional. + # Remove TRL's unconditional bfloat16 cast of trainable params (added in + # TRL 0.26.0). TRL hardcodes bfloat16 for QLoRA per the original paper's + # recommendation, but this is wrong: it ignores the user's requested dtype + # and breaks GradScaler when training with fp16=True. 
Unsloth already + # handles adapter dtype correctly via patch_model_and_tokenizer, so the + # entire block is unnecessary. For GRPOTrainer the enclosing peft init + # block is already removed above, making this a no-op for GRPO. RLTrainer_source = RLTrainer_source.replace( - "param.data = param.data.to(torch.bfloat16)", - "param.data = param.data.to(torch.float32 if getattr(args, 'fp16', False) else torch.bfloat16)", + 'if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False):', + "if False:", ) if RLTrainer_name == "SFTTrainer":