merge_adapter.py
import argparse
import os

import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer


def main():
    parser = argparse.ArgumentParser(description="Merge a (Q)LoRA adapter into its base model.")
    parser.add_argument('model_name', type=str,
                        help='The name of the tuned adapter model that you pushed to Hugging Face after finetuning or DPO.')
    parser.add_argument('output_name', type=str,
                        help='The name of the output (merged) model. Can either be on Hugging Face or on disk.')
    parser.add_argument('--cpu', action='store_true',
                        help="Forces usage of the CPU. By default the GPU is used if available.")
    args = parser.parse_args()

    model_name = args.model_name
    force_cpu = args.cpu
    device = torch.device("cuda:0" if torch.cuda.is_available() and not force_cpu else "cpu")
    output_model = args.output_name

    # Load the adapter together with its base model, then fold the LoRA weights
    # into the base weights so the result is a plain (adapter-free) model.
    model = AutoPeftModelForCausalLM.from_pretrained(model_name, device_map=device,
                                                     torch_dtype=torch.bfloat16, trust_remote_code=True)
    model = model.merge_and_unload()
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    if os.path.isdir(output_model):
        # Save locally. Note that the target directory must already exist
        # for it to be detected as a local path.
        model.save_pretrained(output_model)
        tokenizer.save_pretrained(output_model)
    else:
        # Push to the Hugging Face Hub. Requires the HF_TOKEN environment variable to be set, see
        # https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hftoken
        model.push_to_hub(output_model)
        tokenizer.push_to_hub(output_model)


if __name__ == "__main__":
    main()
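
Example invocation: python merge_adapter.py your-user/your-adapter merged-model (the adapter and output names here are hypothetical placeholders). As a sanity check after merging, the output can be reloaded with plain transformers, no peft required. Below is a minimal sketch, assuming the merge was saved to a local directory named merged-model:

# Sketch: reload the merged model for a quick generation test.
# "merged-model" is a hypothetical local output path from the script above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "merged-model", torch_dtype=torch.bfloat16, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained("merged-model", trust_remote_code=True)

# Generate a few tokens to confirm the merged weights load and run.
inputs = tokenizer("Hello, world!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))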