diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 664cc741f..98c2a701b 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -2395,6 +2395,8 @@ def patch_lora_forwards(torch_compile_options): # When autocast is disabled and base_layer has float32 weights, # cast x to match the weight dtype to prevent dtype mismatch. # For 8-bit layers, the base_layer call was already replaced above. + # For 4-bit layers, weight.dtype is uint8 (packed quantized bytes), + # so we must skip the cast to avoid corrupting input values. _base_layer_call = "result = self.base_layer(x, *args, **kwargs)" _m = re.search(r'^( *)' + re.escape(_base_layer_call), source, re.MULTILINE) if _m: @@ -2403,6 +2405,7 @@ def patch_lora_forwards(torch_compile_options): _base_layer_call, f"if not torch.is_autocast_enabled() and hasattr(self.base_layer, 'weight') " f"and self.base_layer.weight is not None " + f"and not hasattr(self.base_layer.weight, 'quant_state') " f"and x.dtype != self.base_layer.weight.dtype:\n" f"{_ind} x = x.to(self.base_layer.weight.dtype)\n" f"{_ind}{_base_layer_call}",