diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 876ccb25a4..0b8092eaac 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1702,6 +1702,11 @@ def for_inference(model): lm_head = internal_model.lm_head.weight device_type = lm_head.device.type dtype = model.config.torch_dtype + + if type(dtype) is str: + if dtype == "float16": dtype = torch.float16 + elif dtype == "bfloat16": dtype = torch.bfloat16 + pass # Wrap model.generate model._unwrapped_old_generate = model.generate diff --git a/unsloth/save.py b/unsloth/save.py index d1cd7d6361..d0010321b1 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -183,7 +183,7 @@ def unsloth_save_model( ): if token is None and "HF_TOKEN" in os.environ: token = os.environ["HF_TOKEN"] - + if token is None and "HUGGINGFACE_TOKEN" in os.environ: token = os.environ["HUGGINGFACE_TOKEN"] @@ -489,7 +489,12 @@ def unsloth_save_model( from collections import OrderedDict state_dict = OrderedDict() - torch_dtype = model.config.torch_dtype + torch_dtype = internal_model.config.torch_dtype + if type(torch_dtype) is str: + if torch_dtype == "float16": torch_dtype = torch.float16 + elif torch_dtype == "bfloat16": torch_dtype = torch.bfloat16 + pass + # Check modules to save float32 dtype state_dict["model.embed_tokens.weight"] = internal_model.model.embed_tokens.weight.data.to(torch_dtype)