diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py
index 95c6e65317..65ba9ccf9b 100644
--- a/unsloth/models/vision.py
+++ b/unsloth/models/vision.py
@@ -156,9 +156,9 @@ def unsloth_base_fast_generate(
     FastBaseModel.for_inference(self)
     dtype = _get_dtype(dtype_from_config(self.config))
 
-    # Handle float32 cases
-    if os.environ.get("UNSLOTH_BFLOAT16_MIXED_PRECISION", "0") == "1":
-        dtype = torch.bfloat16
+    # Handle full float32 cases as config.dtype == torch.float32!
+    do_bfloat16_mixed_precision = os.environ.get("UNSLOTH_BFLOAT16_MIXED_PRECISION", "0") == "1"
+    if do_bfloat16_mixed_precision: dtype = torch.bfloat16
 
     # Check if VLM
     is_vlm = any(
@@ -218,7 +218,6 @@ def unsloth_base_fast_generate(
         dtype = torch.float16
     else:
         autocaster = torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = dtype)
-
     # Prepare LoRA
     # state_dict = convert_lora_modules(self, dtype = dtype)
 
@@ -255,6 +254,8 @@ def unsloth_base_fast_generate(
             cache_implementation = "hybrid"
         else:
             cache_implementation = "static"
+        # [TODO] Unsure why static fails
+        if do_bfloat16_mixed_precision: cache_implementation = None
 
     if "generation_config" in kwargs:
         kwargs["generation_config"].cache_implementation = cache_implementation
@@ -525,7 +526,7 @@ def from_pretrained(
                         f"To enable bfloat16 training to reduce VRAM usage by 50% albeit with a slightly higher loss, do:\n"\
                         "use `float32_mixed_precision = False` during FastLanguageModel.from_pretrained"
                     )
-                    os.environ["UNSLOTH_BFLOAT16_MIXED_PRECISION"] = "1"
+                os.environ["UNSLOTH_BFLOAT16_MIXED_PRECISION"] = "1"
             else:
                 print("Unsloth: Float16 full finetuning uses more memory since we upcast weights to float32.")
         else:
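
Note on the generate-side change: the weights stay in float32 (config.dtype == torch.float32) while compute is routed through bfloat16 autocast whenever UNSLOTH_BFLOAT16_MIXED_PRECISION is set. A minimal sketch of that pattern, assuming a Hugging Face-style model with a .generate() method; pick_generation_dtype and generate_with_autocast are illustrative names, not unsloth APIs:

    import os
    import torch

    # Stand-in for unsloth's DEVICE_TYPE_TORCH constant.
    DEVICE_TYPE_TORCH = "cuda" if torch.cuda.is_available() else "cpu"

    def pick_generation_dtype(config_dtype):
        # The flag means: the config dtype is float32, but compute should use bfloat16.
        if os.environ.get("UNSLOTH_BFLOAT16_MIXED_PRECISION", "0") == "1":
            return torch.bfloat16
        return config_dtype

    def generate_with_autocast(model, inputs, config_dtype = torch.float32):
        dtype = pick_generation_dtype(config_dtype)
        # Autocast runs matmuls in `dtype` while the float32 weights stay untouched.
        with torch.inference_mode(), torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = dtype):
            return model.generate(**inputs)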
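
Note on the cache override: setting cache_implementation = None falls back to transformers' default dynamic cache instead of compiling a static one, which is the workaround for the unexplained static-cache failure flagged by the [TODO] above. A hedged sketch of the selection logic; the has_sliding_window argument is an assumption standing in for the model check that the hunk truncates:

    import os

    def pick_cache_implementation(has_sliding_window):
        # Under bfloat16 mixed precision, skip static/hybrid caches entirely and
        # let generate() fall back to its default cache (cache_implementation = None).
        if os.environ.get("UNSLOTH_BFLOAT16_MIXED_PRECISION", "0") == "1":
            return None
        # Hybrid caches pair with sliding-window attention; static otherwise.
        return "hybrid" if has_sliding_window else "static"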