-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathFine-tune Llama3 with ORPO
1 lines (1 loc) · 14.5 KB
/
Fine-tune Llama3 with ORPO
1
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.13","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[],"dockerImageVersionId":30699,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[
{"cell_type":"markdown","source":"# Overview\n\nORPO is an exciting new fine-tuning technique that combines the traditional supervised fine-tuning and preference alignment stages into a single process. This reduces the computational resources and time required for training. Moreover, empirical results demonstrate that ORPO outperforms other alignment methods across various model sizes and benchmarks.\n\nWe will fine-tune the Llama 3 8B model using ORPO with the TRL library.","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19"}},
{"cell_type":"code","source":"!pip install -U -q transformers==4.39.3\n!pip install -U -q accelerate==0.28.0\n!pip install -U -q datasets==2.18.0\n!pip install -U -q peft==0.10.0\n!pip install -U -q bitsandbytes==0.43.1\n!pip install -U -q trl==0.8.6","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:32:41.033793Z","iopub.execute_input":"2024-06-24T07:32:41.034191Z","iopub.status.idle":"2024-06-24T07:33:57.905615Z","shell.execute_reply.started":"2024-06-24T07:32:41.03416Z","shell.execute_reply":"2024-06-24T07:33:57.904333Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"markdown","source":"### Note: The next cell installs Flash Attention 2 and selects bfloat16 when the GPU supports it (compute capability >= 8); otherwise it falls back to eager attention with float16.","metadata":{}},
{"cell_type":"code","source":"import torch\n\nif torch.cuda.get_device_capability()[0] >= 8:\n    !pip install -qqq flash-attn\n    attn_implementation = \"flash_attention_2\"\n    torch_dtype = torch.bfloat16\nelse:\n    attn_implementation = \"eager\"\n    torch_dtype = torch.float16","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:33:57.907731Z","iopub.execute_input":"2024-06-24T07:33:57.90806Z","iopub.status.idle":"2024-06-24T07:33:59.77834Z","shell.execute_reply.started":"2024-06-24T07:33:57.90803Z","shell.execute_reply":"2024-06-24T07:33:59.777462Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"import os\nfrom huggingface_hub import login\nfrom kaggle_secrets import UserSecretsClient\nuser_secrets = UserSecretsClient()\nlogin(token=user_secrets.get_secret(\"HUGGINGFACE_TOKEN\"))\n\nos.environ[\"WANDB_API_KEY\"]=user_secrets.get_secret(\"WANDB_API_KEY\")\nos.environ[\"WANDB_PROJECT\"] = \"Fine-tuning Llama 3 8B\"\nos.environ[\"WANDB_NAME\"] = \"ft-Llama3-8b-orpo\"\nos.environ[\"MODEL_NAME\"] = \"meta-llama/Meta-Llama-3-8B\"\nos.environ[\"DATASET\"] = \"mlabonne/orpo-dpo-mix-40k\"\n\ntorch.backends.cudnn.deterministic=True\n# https://github.com/huggingface/transformers/issues/28731\ntorch.backends.cuda.enable_mem_efficient_sdp(False)","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:33:59.779605Z","iopub.execute_input":"2024-06-24T07:33:59.78011Z","iopub.status.idle":"2024-06-24T07:34:01.954568Z","shell.execute_reply.started":"2024-06-24T07:33:59.780074Z","shell.execute_reply":"2024-06-24T07:34:01.953749Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"!accelerate estimate-memory ${MODEL_NAME} --library_name transformers","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:34:01.95585Z","iopub.execute_input":"2024-06-24T07:34:01.956228Z","iopub.status.idle":"2024-06-24T07:34:09.978745Z","shell.execute_reply.started":"2024-06-24T07:34:01.956192Z","shell.execute_reply":"2024-06-24T07:34:09.977692Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"markdown","source":"# Quantization with QLoRA","metadata":{}},
{"cell_type":"code","source":"from transformers import BitsAndBytesConfig\nfrom peft import LoraConfig\n\nbnb_config=BitsAndBytesConfig(\n    load_in_4bit=True,\n    bnb_4bit_quant_type=\"nf4\",\n    bnb_4bit_compute_dtype=torch_dtype,\n    bnb_4bit_use_double_quant=True\n    # llm_int8_enable_fp32_cpu_offload=True\n)\n\npeft_config=LoraConfig(\n    r=16,\n    lora_alpha=32,\n    lora_dropout=0.05,\n    bias=\"none\",\n    task_type=\"CAUSAL_LM\",\n    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']\n)","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:34:09.981626Z","iopub.execute_input":"2024-06-24T07:34:09.98199Z","iopub.status.idle":"2024-06-24T07:34:11.532906Z","shell.execute_reply.started":"2024-06-24T07:34:09.981957Z","shell.execute_reply":"2024-06-24T07:34:11.532023Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"from transformers import AutoTokenizer\n\ntokenizer=AutoTokenizer.from_pretrained(os.getenv('MODEL_NAME'))","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:34:11.534097Z","iopub.execute_input":"2024-06-24T07:34:11.534548Z","iopub.status.idle":"2024-06-24T07:34:12.313303Z","shell.execute_reply.started":"2024-06-24T07:34:11.534519Z","shell.execute_reply":"2024-06-24T07:34:12.312329Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"from transformers import AutoModelForCausalLM\n\nmodel=AutoModelForCausalLM.from_pretrained(\n    os.getenv('MODEL_NAME'),\n    quantization_config=bnb_config,\n    # https://github.com/huggingface/trl/issues/1571#issuecomment-2075404536\n    # https://github.com/xfactlab/orpo/issues/18\n    device_map={\"\":0},\n    torch_dtype=torch_dtype,\n    # NOTE: attn_implementation computed above is currently NOT passed, so transformers picks its default;\n    # uncomment the next line to use the value selected in the Flash Attention cell.\n    # attn_implementation=attn_implementation\n)\n\nmodel.device","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:34:12.314737Z","iopub.execute_input":"2024-06-24T07:34:12.315046Z","iopub.status.idle":"2024-06-24T07:35:27.09856Z","shell.execute_reply.started":"2024-06-24T07:34:12.315018Z","shell.execute_reply":"2024-06-24T07:35:27.097566Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"def print_trainable_parameters(model):\n    \"\"\"Print the count of trainable parameters, total parameters, and their ratio for `model`.\"\"\"\n    trainable_params=0\n    all_params=0\n    for _, param in model.named_parameters():\n        all_params+=param.numel()\n        if param.requires_grad:\n            trainable_params+=param.numel()\n    print(f\"trainable params: {trainable_params} || all params: {all_params} || trainable%: {100 * trainable_params/all_params:.2f}\")\n\nprint_trainable_parameters(model)","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:35:27.099915Z","iopub.execute_input":"2024-06-24T07:35:27.100297Z","iopub.status.idle":"2024-06-24T07:35:27.109738Z","shell.execute_reply.started":"2024-06-24T07:35:27.100262Z","shell.execute_reply":"2024-06-24T07:35:27.108807Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"markdown","source":"# Set chat format and freeze pretrained weights","metadata":{}},
{"cell_type":"code","source":"from trl import setup_chat_format\nfrom peft import prepare_model_for_kbit_training\n\nmodel, tokenizer=setup_chat_format(model, tokenizer)\n\nmodel=prepare_model_for_kbit_training(model)","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:35:27.111059Z","iopub.execute_input":"2024-06-24T07:35:27.111473Z","iopub.status.idle":"2024-06-24T07:35:27.229318Z","shell.execute_reply.started":"2024-06-24T07:35:27.111431Z","shell.execute_reply":"2024-06-24T07:35:27.228285Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"print_trainable_parameters(model)","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:35:27.23062Z","iopub.execute_input":"2024-06-24T07:35:27.230892Z","iopub.status.idle":"2024-06-24T07:35:27.237728Z","shell.execute_reply.started":"2024-06-24T07:35:27.230869Z","shell.execute_reply":"2024-06-24T07:35:27.236882Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"markdown","source":"# Loading Dataset","metadata":{}},
{"cell_type":"code","source":"from datasets import load_dataset\n\n# Note: if you have enough compute, consider training on the full dataset instead of the first 300 rows.\n# ds=load_dataset(os.getenv('DATASET'), split='all')\n# ds=ds.shuffle(seed=42).select(range(1000))\n\n\nds=load_dataset(os.getenv('DATASET'), split='train[:300]')\nds","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:35:27.239128Z","iopub.execute_input":"2024-06-24T07:35:27.239432Z","iopub.status.idle":"2024-06-24T07:35:32.79231Z","shell.execute_reply.started":"2024-06-24T07:35:27.239402Z","shell.execute_reply":"2024-06-24T07:35:32.791423Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"ds=ds.shuffle(seed=42)\n\ndef format_chat_template(row):\n    row[\"chosen\"] = tokenizer.apply_chat_template(row[\"chosen\"], tokenize=False)\n    row[\"rejected\"] = tokenizer.apply_chat_template(row[\"rejected\"], tokenize=False)\n    return row\n\nds=ds.map(format_chat_template, num_proc=os.cpu_count())","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:35:32.793744Z","iopub.execute_input":"2024-06-24T07:35:32.794892Z","iopub.status.idle":"2024-06-24T07:35:33.057679Z","shell.execute_reply.started":"2024-06-24T07:35:32.794853Z","shell.execute_reply":"2024-06-24T07:35:33.056652Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"# Seed the split so the train/eval partition is reproducible across runs\nds=ds.train_test_split(test_size=0.01, seed=42)\nds","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:35:33.05891Z","iopub.execute_input":"2024-06-24T07:35:33.059187Z","iopub.status.idle":"2024-06-24T07:35:33.085316Z","shell.execute_reply.started":"2024-06-24T07:35:33.059162Z","shell.execute_reply":"2024-06-24T07:35:33.084425Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"markdown","source":"# Fine-tuning\n\nWe need to set a few hyperparameters for the ORPO configuration.\n\n### learning_rate\n\nORPO uses a low learning rate. The value of 8e-6 comes from the original paper; for comparison, typical values are 1e-5 for SFT and 5e-6 for DPO.\n\n### beta\n\nIt is the $\\lambda$ parameter in the paper, with a default value of 0.1.\n\n### max_length, batch_size\n\nOther parameters, like `max_length` and the batch size, are set to use as much VRAM as available (~20 GB).","metadata":{}},
{"cell_type":"code","source":"from trl import ORPOConfig, ORPOTrainer\n\n# https://github.com/huggingface/trl/blob/v0.8.6/trl/trainer/orpo_config.py\norpo_args=ORPOConfig(\n    learning_rate=8e-6,\n    beta=0.1,\n    lr_scheduler_type=\"linear\",\n    max_length=1024,\n    max_prompt_length=512,\n    per_device_train_batch_size=2,\n    per_device_eval_batch_size=2,\n    gradient_accumulation_steps=4,\n    optim=\"paged_adamw_8bit\",\n    num_train_epochs=1,\n    evaluation_strategy=\"steps\",\n    eval_steps=0.2,\n    logging_steps=1,\n    warmup_steps=10,\n    report_to=\"wandb\",\n    run_name=os.getenv('WANDB_NAME'),\n    output_dir=os.getenv('WANDB_NAME')\n)\n\ntrainer=ORPOTrainer(\n    model=model,\n    args=orpo_args,\n    train_dataset=ds[\"train\"],\n    eval_dataset=ds[\"test\"],\n    peft_config=peft_config,\n    tokenizer=tokenizer\n)\n\ntrainer.train()","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:35:33.089413Z","iopub.execute_input":"2024-06-24T07:35:33.089758Z","iopub.status.idle":"2024-06-24T07:36:15.293162Z","shell.execute_reply.started":"2024-06-24T07:35:33.089725Z","shell.execute_reply":"2024-06-24T07:36:15.291245Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"kwargs={\n    'model_name': os.getenv(\"WANDB_NAME\"),\n    'finetuned_from': os.getenv('MODEL_NAME'),\n    # 'tasks': '',\n    # 'dataset_tags':'',\n    'dataset': os.getenv(\"DATASET\")\n}\n\ntokenizer.push_to_hub(os.getenv(\"WANDB_NAME\"))\ntrainer.push_to_hub(**kwargs)","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:36:15.294112Z","iopub.status.idle":"2024-06-24T07:36:15.294609Z","shell.execute_reply.started":"2024-06-24T07:36:15.294343Z","shell.execute_reply":"2024-06-24T07:36:15.294363Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"markdown","source":"# Merge and push merged model","metadata":{}},
{"cell_type":"code","source":"import gc\n\ndel trainer, model\ngc.collect()\n\ntorch.cuda.empty_cache()","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:36:15.296217Z","iopub.status.idle":"2024-06-24T07:36:15.296687Z","shell.execute_reply.started":"2024-06-24T07:36:15.296456Z","shell.execute_reply":"2024-06-24T07:36:15.296474Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"tokenizer=AutoTokenizer.from_pretrained(os.getenv('MODEL_NAME'))","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:36:15.297636Z","iopub.status.idle":"2024-06-24T07:36:15.298067Z","shell.execute_reply.started":"2024-06-24T07:36:15.297843Z","shell.execute_reply":"2024-06-24T07:36:15.297861Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"model = AutoModelForCausalLM.from_pretrained(\n    os.getenv('MODEL_NAME'),\n    low_cpu_mem_usage=True,\n    return_dict=True,\n    torch_dtype=torch.float16,\n    device_map=\"cuda\",\n)\n\n# `device` is a property, not a method (calling it raises TypeError)\nmodel.device","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:36:15.300031Z","iopub.status.idle":"2024-06-24T07:36:15.300527Z","shell.execute_reply.started":"2024-06-24T07:36:15.300267Z","shell.execute_reply":"2024-06-24T07:36:15.300284Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"from peft import PeftModel\n\n# Re-apply the same chat format used before training so the base model and tokenizer match the adapter checkpoint\nmodel, tokenizer = setup_chat_format(model, tokenizer)\n\n# Merge adapter with base model\nmodel = PeftModel.from_pretrained(model, os.getenv(\"WANDB_NAME\"))\nmodel = model.merge_and_unload()","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:36:15.302286Z","iopub.status.idle":"2024-06-24T07:36:15.302654Z","shell.execute_reply.started":"2024-06-24T07:36:15.302474Z","shell.execute_reply":"2024-06-24T07:36:15.302494Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":"# model.push_to_hub(os.getenv(\"WANDB_NAME\"), use_temp_dir=False)\n# tokenizer.push_to_hub(os.getenv(\"WANDB_NAME\"), use_temp_dir=False)","metadata":{"execution":{"iopub.status.busy":"2024-06-24T07:36:15.303991Z","iopub.status.idle":"2024-06-24T07:36:15.304324Z","shell.execute_reply.started":"2024-06-24T07:36:15.304162Z","shell.execute_reply":"2024-06-24T07:36:15.304175Z"},"trusted":true},"execution_count":null,"outputs":[]},
{"cell_type":"markdown","source":"# Acknowledgements\n\n* https://www.kaggle.com/code/aisuko/fine-tuning-phi-2-with-qlora\n* https://medium.com/towards-data-science/fine-tune-llama-3-with-orpo-56cfab2f9ada\n* https://www.kaggle.com/code/aisuko/llm-prompt-recovery-with-gemma","metadata":{}}
]}