From 79af87872a255c94505ae68b0b3b9b8a39967d90 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 27 Mar 2026 07:10:42 +0000 Subject: [PATCH] Fix LoRA forward returning wrong dtype when autocast is disabled The compiler replaces PEFT's inline LoRA computation with an early return through lora_forward(). This skips the original dtype cast back to the base layer dtype, leaving the output in float32 (LoRA weight dtype) instead of e.g. bf16. The mismatch only surfaces when multiple models are loaded in the same process (the patched forward persists globally), causing downstream ops like SDPA to receive float32 queries with bf16 attention masks. Append .to() to the early return so the output dtype always matches the base layer, same as the original PEFT code path. For variants that reassign result to float32 (Linear, GPTQ, LoraParallel) use the saved torch_result_dtype; for variants that only cast x (Linear4bit, Linear8bitLt) use result.dtype directly since result is untouched. --- unsloth_zoo/compiler.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 220e23125..f714c8d49 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -2491,7 +2491,15 @@ def patch_lora_forwards(torch_compile_options): if (old1 not in source and add not in source) and (old2 not in source): pass else: - replace = "return lora_forward(result, lora_A, lora_B, dropout, x, scaling)" + # Linear/GPTQ/LoraParallel reassign result to float32 before the + # loop, so they save the original dtype in torch_result_dtype. + # Linear4bit/Linear8bitLt only cast x, leaving result untouched, + # so result.dtype is still the base-layer dtype at return time. + if re.search(r"\btorch_result_dtype\s*=\s*result\.dtype\b", source): + dtype_cast = "torch_result_dtype" + else: + dtype_cast = "result.dtype" + replace = f"return lora_forward(result, lora_A, lora_B, dropout, x, scaling).to({dtype_cast})" source = source.replace(old1, replace) source = source.replace(old2, replace) pass