diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 953991d41d..8b8ec1b1a0 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -584,6 +584,11 @@ def from_pretrained( if transformers_version < Version("4.53.0"): raise RuntimeError("Unsloth: Gemma 3N only works on transformers >= 4.53.0" + LATEST) + elif "falcon-h1" in lowered_model_name: + os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = \ + "float16;torch.float32;torch.float16;"\ + "if name.endswith(('q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'head')): module.to(torch.float16);" + os.environ["TRITON_F32_DEFAULT"] = "ieee" else: for check_model_name in DISABLE_COMPILE_MODEL_NAMES: if check_model_name in lowered_model_name: