diff --git a/generate_lora.py b/generate_lora.py index 6788d53c..b67def31 100644 --- a/generate_lora.py +++ b/generate_lora.py @@ -75,7 +75,7 @@ def main( t0 = time.time() with (lazy_load(pretrained_path) as pretrained_checkpoint, - lazy_load(lora_path) as adapter_checkpoint): + lazy_load(lora_path) as lora_checkpoint): name = llama_model_lookup(pretrained_checkpoint) with EmptyInitOnDevice( @@ -85,8 +85,8 @@ def main( # 1. Load the pretrained weights model.load_state_dict(pretrained_checkpoint, strict=False) - # 2. Load the fine-tuned adapter weights - model.load_state_dict(adapter_checkpoint, strict=False) + # 2. Load the fine-tuned lora weights + model.load_state_dict(lora_checkpoint, strict=False) print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)