Fine-tuning Mistral-7b with DPO
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[],"dockerImageVersionId":30627,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"# Overview\n\n**Note: All images are from the sources listed in the Credits section or from the internet. This notebook has not been fully verified because of insufficient computing resources, and the bf16 type is not supported in the current environment.**\n\nPre-trained LLMs can only perform next-token prediction, which makes them unable to answer questions. This is why these base models are then fine-tuned on pairs of instructions and answers to act as helpful assistants. However, this process can still be flawed: fine-tuned LLMs can be biased, toxic, harmful, etc. This is where RLHF (Reinforcement Learning from Human Feedback) comes into play.\n\nRLHF provides different answers to the LLM, which are ranked according to a desired behavior (helpfulness, toxicity, etc.). The model learns to output the best answer among these candidates, hence mimicking the behavior we want to instill. 
Often seen as a way to censor models, this process has recently become popular for improving performance, as shown in [neural-chat-7b-v3-1](https://huggingface.co/Intel/neural-chat-7b-v3-1).\n\nWe are going to fine-tune OpenHermes-2.5 using an RLHF-like technique: **Direct Preference Optimization (DPO)**.","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19"}},{"cell_type":"code","source":"%%capture\n!pip install transformers==4.36.2\n!pip install datasets==2.15.0\n!pip install peft==0.7.1\n!pip install bitsandbytes==0.41.3\n!pip install trl==0.7.7","metadata":{"execution":{"iopub.status.busy":"2024-08-30T16:09:54.352046Z","iopub.execute_input":"2024-08-30T16:09:54.352441Z","iopub.status.idle":"2024-08-30T16:11:02.405575Z","shell.execute_reply.started":"2024-08-30T16:09:54.352396Z","shell.execute_reply":"2024-08-30T16:11:02.404262Z"},"trusted":true},"execution_count":2,"outputs":[]},{"cell_type":"code","source":"import os\nfrom huggingface_hub import login\nfrom kaggle_secrets import UserSecretsClient\n\n# read API tokens from the Kaggle secrets store\nuser_secrets = UserSecretsClient()\n\n# authenticate to the Hugging Face Hub (needed for push_to_hub later)\nlogin(token=user_secrets.get_secret(\"HUGGINGFACE_TOKEN\"))\n\n# configure Weights & Biases logging\nos.environ[\"WANDB_API_KEY\"]=user_secrets.get_secret(\"WANDB_API_KEY\")\nos.environ[\"WANDB_PROJECT\"] = \"Fine-tune-models\"\nos.environ[\"WANDB_NOTES\"] = \"Fine-tune OpenHermes-2.5-Mistral-7B with DPO on Intel/orca_dpo_pairs\"\nos.environ[\"WANDB_NAME\"] = \"ft-openhermes-25-mistral-7b-irca-dpo-pairs\"","metadata":{"execution":{"iopub.status.busy":"2024-02-23T04:51:47.510911Z","iopub.execute_input":"2024-02-23T04:51:47.511251Z","iopub.status.idle":"2024-02-23T04:51:48.553236Z","shell.execute_reply.started":"2024-02-23T04:51:47.511223Z","shell.execute_reply":"2024-02-23T04:51:48.552217Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"# Preparing datasets\n\nWe use **preference datasets**, which typically consist of a collection of answers ranked by humans. 
This ranking is essential, as the RLHF process fine-tunes LLMs to output the preferred answer.","metadata":{}},{"cell_type":"code","source":"from datasets import load_dataset\n\n# load the train split of the preference dataset\ntrain_ds=load_dataset(\"Intel/orca_dpo_pairs\")[\"train\"]","metadata":{"execution":{"iopub.status.busy":"2024-02-23T04:51:48.554442Z","iopub.execute_input":"2024-02-23T04:51:48.554767Z","iopub.status.idle":"2024-02-23T04:51:54.192161Z","shell.execute_reply.started":"2024-02-23T04:51:48.554738Z","shell.execute_reply":"2024-02-23T04:51:54.191428Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"The structure of the dataset is straightforward: each row has one chosen (preferred) answer and one rejected answer. The goal of RLHF is to guide the model to output the preferred answer.","metadata":{}},{"cell_type":"markdown","source":"## Preference datasets\n\nPreference datasets are notoriously costly and difficult to build, as they require collecting manual feedback from humans. This feedback is also subjective: it can easily be biased toward confident (but wrong) answers or contradict itself (different annotators have different values). Over time, several solutions have been proposed to tackle these issues, such as replacing human feedback with AI feedback ([RLAIF](https://arxiv.org/abs/2212.08073)).\n\nThese datasets also tend to be a lot smaller than fine-tuning datasets. To illustrate this, the excellent [neural-chat-7b-v3-1](https://huggingface.co/Intel/neural-chat-7b-v3-1) uses 518k samples for fine-tuning ([Open-Orca/SlimOrca](https://huggingface.co/datasets/Open-Orca/SlimOrca)) but only 12.9k samples for RLHF ([Intel/orca_dpo_pairs](https://huggingface.co/datasets/Intel/orca_dpo_pairs)). In this case, the authors generated the preferred answers with GPT-4/3.5 and the rejected responses with [Llama 2 13b chat](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf). 
It's a smart way to bypass human feedback and rely only on models with different levels of performance.","metadata":{}},{"cell_type":"markdown","source":"# Direct Preference Optimization\n\nWhile the concept of RLHF has been used in robotics for a long time, it was popularized for LLMs in [Fine-tuning Language Models from Human Preferences](https://arxiv.org/pdf/1909.08593.pdf). In this paper, the authors present a framework where a reward model is trained to approximate human feedback. This reward model is then used to optimize the fine-tuned model's policy using the [Proximal Policy Optimization algorithm](https://arxiv.org/abs/1707.06347).\n\n<div style=\"text-align: center\"><img src=\"https://files.mastodon.social/media_attachments/files/111/707/466/912/510/214/original/5f3fa6b5a0b186a3.webp\" width=\"60%\" height=\"60%\" alt=\"proximal policy optimization\"></div>\n\n\nThe core concept of PPO revolves around making smaller, incremental updates to the policy, as larger updates can lead to instability or suboptimal solutions. From experience, this technique is unfortunately still unstable (the loss can diverge), difficult to reproduce (numerous hyperparameters, sensitive to random seeds), and computationally expensive.\n\nThis is where Direct Preference Optimization (DPO) comes into play. DPO simplifies control by treating the task as a classification problem. Concretely, it uses two models: the **trained model (or policy model)** and a copy of it called the **reference model**. During training, the goal is to make sure the trained model outputs higher probabilities for preferred answers than the reference model does. Conversely, we also want it to output lower probabilities for rejected answers. 
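
A minimal pure-Python sketch of that objective, following the loss from the DPO paper, is shown here. It assumes we already have the summed log-probabilities of each answer under the trained and reference models; `dpo_loss` and its argument names are hypothetical and are not the TRL API. The loss is -log(sigmoid(margin)), where the margin measures how much more the trained model favors the chosen answer over the rejected one than the reference model does, scaled by `beta` (the same role as `beta=0.1` in `DPOTrainer` later on).

```python
import math


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    # Implicit rewards: how much the trained model raises each answer's
    # log-probability relative to the frozen reference model.
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = chosen_reward - rejected_reward
    # Binary cross-entropy on the preference: -log(sigmoid(margin)),
    # written in a numerically stable form.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


# Before training the two models are identical, so the loss is log(2).
print(dpo_loss(-12.0, -15.0, -12.0, -15.0))
# Raising the chosen answer and lowering the rejected one reduces the loss.
print(dpo_loss(-10.0, -17.0, -12.0, -15.0))
```
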
In other words, we penalize the LLM for bad answers and reward it for good ones.\n\n<div style=\"text-align: center\"><img src=\"https://files.mastodon.social/media_attachments/files/111/707/528/109/748/161/original/7c1ef620f48dd4a6.webp\" width=\"60%\" height=\"60%\" alt=\"direct preference optimization\"></div>\n\nBy using the LLM itself as a reward model and employing a binary cross-entropy objective, DPO efficiently aligns the model's outputs with human preferences without the need for extensive sampling, reward model fitting, or intricate hyperparameter adjustments. The result is a more stable, more efficient, and computationally less demanding process.","metadata":{}},{"cell_type":"markdown","source":"# Formatting the data\n\nHere we are going to fine-tune [OpenHermes-2.5-Mistral-7B](https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B), a Mistral-7b model that has only been supervised fine-tuned. To align it and improve its performance, we will use the [Intel/orca_dpo_pairs](https://huggingface.co/datasets/Intel/orca_dpo_pairs) dataset.","metadata":{}},{"cell_type":"code","source":"from peft import LoraConfig, PeftModel, get_peft_model\nfrom trl import 
DPOTrainer","metadata":{"execution":{"iopub.status.busy":"2024-02-23T04:51:54.193221Z","iopub.execute_input":"2024-02-23T04:51:54.193678Z","iopub.status.idle":"2024-02-23T04:52:19.183556Z","shell.execute_reply.started":"2024-02-23T04:51:54.193645Z","shell.execute_reply":"2024-02-23T04:52:19.182486Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"model_name=\"teknium/OpenHermes-2.5-Mistral-7B\"","metadata":{"execution":{"iopub.status.busy":"2024-02-23T04:52:19.187501Z","iopub.execute_input":"2024-02-23T04:52:19.189157Z","iopub.status.idle":"2024-02-23T04:52:19.19602Z","shell.execute_reply.started":"2024-02-23T04:52:19.189006Z","shell.execute_reply":"2024-02-23T04:52:19.194074Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"OpenHermes-2.5-Mistral-7B uses a specific chat template called [ChatML](https://huggingface.co/docs/transformers/chat_templating). Here is an example of a conversation formatted with this template:\n\n```\n<|im_start|>system\nYou are a helpful chatbot that will do its best not to say anything so stupid that people tweet about it.<|im_end|>\n<|im_start|>user\nHow are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great!<|im_end|>\n```\n\nEach message has a role (\"system\", \"user\", or \"assistant\") and is wrapped in special tokens that mark its beginning and end (`<|im_start|>` and `<|im_end|>`); these tokens separate consecutive messages. Moreover, `DPOTrainer` requires a specific format with three columns: prompt, chosen, and rejected.\n\nWe will simply concatenate the formatted system and question columns into the prompt column. The chosen and rejected columns already hold the preferred (GPT-4/3.5) and rejected (Llama 2 13b chat) answers, so we only append the ChatML end-of-turn token to them. 
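
To make the target layout concrete, here is a hand-written sketch of how one row could be mapped to the three columns, assuming the standard ChatML layout shown above. `format_row` and the template constants are hypothetical helpers for illustration only; in practice the tokenizer's own template is the source of truth and may differ in details (for example a leading BOS token). Triple-quoted strings are used so the newlines in the template are literal.

```python
# ChatML pieces, assuming the layout shown in the example conversation above.
TURN = '''<|im_start|>{role}
{content}<|im_end|>
'''
ASSISTANT_HEADER = '''<|im_start|>assistant
'''
END_OF_TURN = '''<|im_end|>
'''


def format_row(system, question, chosen, rejected):
    # The system turn is optional, as in chatml_format.
    prompt = TURN.format(role='system', content=system) if system else ''
    prompt += TURN.format(role='user', content=question)
    # Generation prompt: open the assistant turn without closing it.
    prompt += ASSISTANT_HEADER
    # Both answers continue the open assistant turn and then close it.
    return {
        'prompt': prompt,
        'chosen': chosen + END_OF_TURN,
        'rejected': rejected + END_OF_TURN,
    }


row = format_row('You are a helpful assistant.', 'How are you?',
                 'I am doing great!', 'Leave me alone.')
print(row['prompt'] + row['chosen'])
```
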
To format the dataset reliably, we will use the tokenizer's `apply_chat_template()` function, which already uses ChatML.","metadata":{}},{"cell_type":"code","source":"from transformers import AutoTokenizer\n\ndef chatml_format(example):\n    # format the system message, if there is one\n    if len(example['system'])>0:\n        message ={\"role\":\"system\",\"content\":example['system']}\n        system=tokenizer.apply_chat_template([message], tokenize=False)\n    else:\n        system=\"\"\n\n    # format the instruction and open the assistant turn\n    message={\"role\":\"user\",\"content\":example['question']}\n    prompt=tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)\n\n    # format the chosen answer, closing it with the ChatML end-of-turn token\n    chosen =example['chosen']+\"<|im_end|>\\n\"\n\n    # format the rejected answer the same way\n    rejected = example['rejected']+\"<|im_end|>\\n\"\n\n    return {\n        \"prompt\": system+prompt,\n        \"chosen\": chosen,\n        \"rejected\": rejected,\n    }\n\n# train_ds was loaded in the section above\n\n# save the original column names so map() can drop them\noriginal_columns=train_ds.column_names\n\n# tokenizer (chatml_format looks it up when map() calls the function)\ntokenizer=AutoTokenizer.from_pretrained(model_name)\ntokenizer.pad_token=tokenizer.eos_token\ntokenizer.padding_side=\"left\"\n\n# format the dataset into prompt/chosen/rejected columns\ntrain_dataset=train_ds.map(\n    function=chatml_format,\n    remove_columns=original_columns\n)\n\n# check one formatted example\ntrain_dataset[0]","metadata":{"execution":{"iopub.status.busy":"2024-02-23T04:52:19.197423Z","iopub.execute_input":"2024-02-23T04:52:19.198326Z","iopub.status.idle":"2024-02-23T04:52:23.294915Z","shell.execute_reply.started":"2024-02-23T04:52:19.198291Z","shell.execute_reply":"2024-02-23T04:52:23.293951Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"# Training the model with DPO\n\nWe define the LoRA configuration used to train the model. We set the rank `r` equal to `lora_alpha`, which is unusual (a common rule of thumb is `lora_alpha = 2 * r`). 
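
To put these two settings in numbers, the sketch below computes the LoRA scaling factor and the number of trainable parameters the adapters add. It assumes the standard LoRA formulation, where the low-rank update is multiplied by `lora_alpha / r` (so `r = lora_alpha = 16` gives a factor of 1.0, against 2.0 under the rule of thumb), and each adapted linear layer gets two matrices of shapes (r, in) and (out, r). The layer shapes are an assumption taken from the public Mistral-7B configuration (hidden size 4096, MLP size 14336, 8 key/value heads of size 128, 32 layers), covering the seven modules in `target_modules`; `lora_scaling` and `lora_param_count` are hypothetical helpers, not PEFT functions.

```python
# Assumed Mistral-7B shapes (in_features, out_features) per decoder layer.
HIDDEN, MLP, KV = 4096, 14336, 8 * 128
NUM_LAYERS = 32
SHAPES = {
    'q_proj': (HIDDEN, HIDDEN),
    'k_proj': (HIDDEN, KV),
    'v_proj': (HIDDEN, KV),
    'o_proj': (HIDDEN, HIDDEN),
    'gate_proj': (HIDDEN, MLP),
    'up_proj': (HIDDEN, MLP),
    'down_proj': (MLP, HIDDEN),
}


def lora_scaling(r, lora_alpha):
    # Standard LoRA multiplies the low-rank update B @ A by alpha / r.
    return lora_alpha / r


def lora_param_count(r, shapes=SHAPES, num_layers=NUM_LAYERS):
    # Each adapted linear layer gets A (r x in) and B (out x r).
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * num_layers


print(lora_scaling(16, 16), lora_scaling(16, 32))  # 1.0 2.0
print(lora_param_count(16))  # 41943040
```

Roughly 42M trainable parameters is well under 1% of the base model's roughly 7B parameters, which is what makes adapter training far cheaper than full fine-tuning.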
We also target all the linear modules with adapters.","metadata":{}},{"cell_type":"code","source":"from transformers import BitsAndBytesConfig\nfrom peft import LoraConfig, PeftModel, get_peft_model\n\nimport torch\n\n# LoRA configuration\npeft_config=LoraConfig(\n r=16,\n lora_alpha=16,\n lora_dropout=0.05,\n bias=\"none\",\n task_type=\"CAUSAL_LM\",\n target_modules=['k_proj','gate_proj','v_proj','up_proj','q_proj','o_proj','down_proj']\n)\n\n# 4-bit NF4 quantization with double quantization; compute in fp16 because the T4 GPU does not support bf16\nbnb_config=BitsAndBytesConfig(\n load_in_4bit=True,\n bnb_4bit_quant_type=\"nf4\",\n bnb_4bit_use_double_quant=True,\n bnb_4bit_compute_dtype=torch.float16,\n# llm_int8_enable_fp32_cpu_offload=True,\n)","metadata":{"execution":{"iopub.status.busy":"2024-02-23T04:52:23.296158Z","iopub.execute_input":"2024-02-23T04:52:23.296445Z","iopub.status.idle":"2024-02-23T04:52:23.304542Z","shell.execute_reply.started":"2024-02-23T04:52:23.29642Z","shell.execute_reply":"2024-02-23T04:52:23.303581Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## Quantize the model\n\nWe're now ready to load the model we want to fine-tune with DPO. In this case, two models are required: the model to fine-tune and the reference model. 
This is mostly for the sake of readability, as the `DPOTrainer` object automatically creates a reference model if none is provided.","metadata":{}},{"cell_type":"code","source":"from transformers import AutoModelForCausalLM\n\n# model to fine-tune, loaded in 4-bit\nmodel=AutoModelForCausalLM.from_pretrained(\n model_name,\n device_map=\"auto\",\n quantization_config=bnb_config,\n torch_dtype=torch.float16,\n)\n\nmodel","metadata":{"execution":{"iopub.status.busy":"2024-02-23T04:52:23.3065Z","iopub.execute_input":"2024-02-23T04:52:23.306979Z","iopub.status.idle":"2024-02-23T04:54:23.805055Z","shell.execute_reply.started":"2024-02-23T04:52:23.30694Z","shell.execute_reply":"2024-02-23T04:54:23.803922Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# the KV cache is not needed during training\nmodel.config.use_cache=False\nmodel=get_peft_model(model, peft_config)\n# only the last expression of a cell is displayed, so print the footprint (in bytes)\nprint(model.get_memory_footprint())\nmodel.config","metadata":{"execution":{"iopub.status.busy":"2024-02-23T04:54:23.806703Z","iopub.execute_input":"2024-02-23T04:54:23.807115Z","iopub.status.idle":"2024-02-23T04:54:24.541892Z","shell.execute_reply.started":"2024-02-23T04:54:23.807078Z","shell.execute_reply":"2024-02-23T04:54:24.540748Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Reference model\nref_model=AutoModelForCausalLM.from_pretrained(\n model_name,\n device_map=\"auto\",\n quantization_config=bnb_config,\n torch_dtype=torch.float16,\n)\nref_model=get_peft_model(ref_model, peft_config)\nprint(ref_model.get_memory_footprint())\nref_model.config","metadata":{"execution":{"iopub.status.busy":"2024-02-23T04:54:24.543248Z","iopub.execute_input":"2024-02-23T04:54:24.543556Z","iopub.status.idle":"2024-02-23T04:54:39.86197Z","shell.execute_reply.started":"2024-02-23T04:54:24.543529Z","shell.execute_reply":"2024-02-23T04:54:39.860979Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"The final step consists of providing all the hyperparameters to `TrainingArguments` and `DPOTrainer`:\n* 
The `beta` parameter is specific to DPO: it controls how far the trained model may diverge from the initial policy (0.1 is a typical value).","metadata":{}},{"cell_type":"code","source":"from transformers import TrainingArguments\nfrom trl import DPOTrainer\n\ntraining_args=TrainingArguments(\n per_device_train_batch_size=2,\n # recompute activations in the backward pass to save GPU memory\n gradient_checkpointing=True,\n # effective batch size per device: 2 * 5 = 10\n gradient_accumulation_steps=5,\n remove_unused_columns=False,\n learning_rate=5e-5,\n lr_scheduler_type=\"cosine\",\n # the total number of training steps to perform\n max_steps=100,\n save_strategy=\"no\",\n logging_steps=100,\n output_dir=os.getenv(\"WANDB_NAME\"),\n optim=\"paged_adamw_32bit\",\n bf16=False, # bf16 is not supported in the Kaggle (T4) environment\n fp16=True,\n # Number of steps used for a linear warmup from 0 to learning_rate\n warmup_steps=50,\n run_name=os.getenv(\"WANDB_NAME\"),\n report_to=\"wandb\"\n)\n\ndpo_trainer=DPOTrainer(\n model,\n ref_model,\n args=training_args,\n train_dataset=train_dataset,\n tokenizer=tokenizer,\n peft_config=peft_config,\n beta=0.1,\n max_prompt_length=1024,\n max_length=1536,\n)\n\n# Note: the Kaggle environment does not have enough memory for this training run; the kernel restarts and training fails.\ndpo_trainer.train()","metadata":{"execution":{"iopub.status.busy":"2024-02-23T04:55:14.957163Z","iopub.execute_input":"2024-02-23T04:55:14.957558Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"# Saving the merged model","metadata":{}},{"cell_type":"code","source":"import gc\n\n# save the trained LoRA adapter and the tokenizer\ndpo_trainer.model.save_pretrained(\"final_checkpoint\")\ntokenizer.save_pretrained(\"final_checkpoint\")\n\n# Flush memory\ndel dpo_trainer, model, ref_model\ngc.collect()\ntorch.cuda.empty_cache()\n\n# Reload the base model in FP16 (instead of NF4) so the adapter is merged into unquantized weights\nbase_model=AutoModelForCausalLM.from_pretrained(\n model_name,\n return_dict=True,\n torch_dtype=torch.float16,\n)\n\ntokenizer=AutoTokenizer.from_pretrained(model_name)\n\n# merge 
base model with the adapter\nmodel=PeftModel.from_pretrained(base_model, \"final_checkpoint\")\nmodel=model.merge_and_unload()\n\n# save model and tokenizer\nmodel.save_pretrained(os.getenv(\"WANDB_NAME\"))\ntokenizer.save_pretrained(os.getenv(\"WANDB_NAME\"))\n\n# push the merged model and tokenizer to the Hugging Face Hub\nmodel.push_to_hub(os.getenv(\"WANDB_NAME\"))\ntokenizer.push_to_hub(os.getenv(\"WANDB_NAME\"))","metadata":{"execution":{"iopub.status.busy":"2024-02-23T04:54:39.872836Z","iopub.status.idle":"2024-02-23T04:54:39.873289Z","shell.execute_reply.started":"2024-02-23T04:54:39.873048Z","shell.execute_reply":"2024-02-23T04:54:39.873069Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"# Checking the model's performance\n\nBesides running the model locally, we can leverage the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) to evaluate it. As the process is quite resource-intensive, we can also submit the model directly for evaluation on the [Open LLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard).","metadata":{}},{"cell_type":"code","source":"import transformers\n\n# format prompt\nmessage=[\n {\"role\":\"system\", \"content\":\"The weather is important.\"},\n {\"role\":\"user\",\"content\":\"Is it raining in Melbourne today?\"}\n]\n\ntokenizer=AutoTokenizer.from_pretrained(os.getenv(\"WANDB_NAME\"))\nprompt=tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False)\n\npipe=transformers.pipeline(\n \"text-generation\",\n model=os.getenv(\"WANDB_NAME\"),\n tokenizer=tokenizer\n)\n\n# generate text\nsequences=pipe(\n prompt,\n do_sample=True,\n temperature=0.7,\n top_p=0.9,\n num_return_sequences=1,\n 
max_length=200,\n)\n\nsequences[0]['generated_text']","metadata":{"execution":{"iopub.status.busy":"2024-02-23T04:54:39.874427Z","iopub.status.idle":"2024-02-23T04:54:39.874846Z","shell.execute_reply.started":"2024-02-23T04:54:39.874637Z","shell.execute_reply":"2024-02-23T04:54:39.874659Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"# Credits\n\n* https://towardsdatascience.com/fine-tune-a-mistral-7b-model-with-direct-preference-optimization-708042745aac","metadata":{}}]}