Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions examples/trl/stack_llama_2/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# DPO pipeline for the creation of StackLlaMa 2: a Stack exchange llama-v2-7b model

## Prerequisites

Install all the dependencies in the `requirements.txt`:

```
$ pip install -U -r requirements.txt
```


## Training

There were two main steps to the DPO training process:
1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se:

```
python ../../gaudi_spawn.py --world_size 8 --use_mpi sft_llama2.py \
--output_dir="./sft" \
--max_steps=500 \
--logging_steps=10 \
--save_steps=10 \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=1 \
--gradient_accumulation_steps=2 \
--learning_rate=1e-4 \
--lr_scheduler_type="cosine" \
--warmup_steps=100 \
--weight_decay=0.05 \
--optim="paged_adamw_32bit" \
--bf16 \
--remove_unused_columns=False \
--run_name="sft_llama2" \
--report_to=none \
--use_habana \
--use_lazy_mode
```
2. Run the DPO trainer using the model saved by the previous step:
```
python ../../gaudi_spawn.py --world_size 8 --use_mpi dpo_llama2.py \
--model_name_or_path="sft/final_merged_checkpoint" \
--output_dir="dpo" \
--report_to=none
```


## Merging the adaptors

To merge the adaptors into the base model we can use the `merge_peft_adapter.py` helper script that comes with TRL:

```
python merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="dpo" --output_name="stack-llama-2"
```

which will also push the model to your HuggingFace hub account.

## Running the model

We can load the DPO-trained LoRA adaptors which were saved by the DPO training step and load them via:

```py
from peft import AutoPeftModelForCausalLM


model = AutoPeftModelForCausalLM.from_pretrained(
"dpo/final_checkpoint",
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16,
)

model.generate(...)
```
231 changes: 231 additions & 0 deletions examples/trl/stack_llama_2/dpo_llama2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
# copy from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py, enable it for Gaudi2
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from optimum.habana import GaudiConfig, GaudiTrainingArguments
from optimum.habana.trl import GaudiDPOTrainer


# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
The arguments for the DPO training script.
"""

# data parameters
beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

# training parameters
model_name_or_path: Optional[str] = field(
default="../sft/results/final_checkpoint",
metadata={"help": "the location of the SFT model name or path"},
)
tokenizer_name_or_path: Optional[str] = field(
default="meta-llama/Llama-2-7b-hf",
metadata={"help": "the location of the SFT model name or path"},
)
learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})

per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "train batch size per device"})
per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
gradient_accumulation_steps: Optional[int] = field(
default=4, metadata={"help": "the number of gradient accumulation steps"}
)
gradient_checkpointing: Optional[bool] = field(
default=False, metadata={"help": "whether to use gradient checkpointing"}
)

lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})

max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})

output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})

# instrumentation
sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
report_to: Optional[str] = field(
default="wandb",
metadata={
"help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
'`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
},
)
# debug argument for distributed training
ignore_bias_buffers: Optional[bool] = field(
default=False,
metadata={
"help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)


def get_stack_exchange_paired(
data_dir: str = "data/rl",
sanity_check: bool = False,
cache_dir: str = None,
num_proc=24,
) -> Dataset:
"""Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.

The dataset is converted to a dictionary with the following structure:
{
'prompt': List[str],
'chosen': List[str],
'rejected': List[str],
}

Prompts are structured as follows:
"Question: " + <prompt> + "\n\nAnswer: "
"""
dataset = load_dataset(
"lvwerra/stack-exchange-paired",
split="train",
cache_dir=cache_dir,
data_dir=data_dir,
)
original_columns = dataset.column_names

if sanity_check:
dataset = dataset.select(range(min(len(dataset), 1000)))

def return_prompt_and_responses(samples) -> Dict[str, str]:
return {
"prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
"chosen": samples["response_j"],
"rejected": samples["response_k"],
}

return dataset.map(
return_prompt_and_responses,
batched=True,
num_proc=num_proc,
remove_columns=original_columns,
)


if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
# 1. initialize training arguments:
training_args = GaudiTrainingArguments(
per_device_train_batch_size=script_args.per_device_train_batch_size,
per_device_eval_batch_size=script_args.per_device_eval_batch_size,
max_steps=script_args.max_steps,
logging_steps=script_args.logging_steps,
save_steps=script_args.save_steps,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
gradient_checkpointing=script_args.gradient_checkpointing,
learning_rate=script_args.learning_rate,
evaluation_strategy="steps",
eval_steps=script_args.eval_steps,
output_dir=script_args.output_dir,
report_to=script_args.report_to,
lr_scheduler_type=script_args.lr_scheduler_type,
warmup_steps=script_args.warmup_steps,
optim=script_args.optimizer_type,
bf16=True,
remove_unused_columns=False,
run_name="dpo_llama2",
use_habana=True,
use_lazy_mode=True,
use_hpu_graphs_for_training=True,
use_hpu_graphs_for_inference=True,
)
# 2. load a pretrained model
model = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16,
)
model.config.use_cache = False

if script_args.ignore_bias_buffers:
# torch distributed hack
model._ddp_params_and_buffers_to_ignore = [
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
]

model_ref = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16,
)
model_ref.config.use_cache = False
tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name_or_path)
tokenizer.pad_token = tokenizer.eos_token

# 3. Load the Stack-exchange paired dataset
train_dataset = get_stack_exchange_paired(data_dir="data/rl", sanity_check=script_args.sanity_check)
train_dataset = train_dataset.filter(
lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
)

# 4. Load evaluation dataset
eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation", sanity_check=True)
eval_dataset = eval_dataset.filter(
lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
)

peft_config = LoraConfig(
r=script_args.lora_r,
lora_alpha=script_args.lora_alpha,
lora_dropout=script_args.lora_dropout,
target_modules=[
"q_proj",
"v_proj",
"k_proj",
"out_proj",
"fc_in",
"fc_out",
"wte",
],
bias="none",
task_type="CAUSAL_LM",
)

gaudi_config = GaudiConfig()
gaudi_config.use_fused_adam = True
gaudi_config.use_fused_clip_norm = True

# 5. initialize the DPO trainer
dpo_trainer = GaudiDPOTrainer(
model,
model_ref,
gaudi_config=gaudi_config,
args=training_args,
beta=script_args.beta,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=peft_config,
max_prompt_length=script_args.max_prompt_length,
max_length=script_args.max_length,
)

# 6. train
dpo_trainer.train()

# 7. save
dpo_trainer.save_model(script_args.output_dir)
50 changes: 50 additions & 0 deletions examples/trl/stack_llama_2/merge_peft_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# copy from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py.
# only difference is removal of model.push_to_hub
from dataclasses import dataclass, field
from typing import Optional

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser


@dataclass
class ScriptArguments:
"""
The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the
merged model.
"""

adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"})
base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"})
output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge"
assert script_args.base_model_name is not None, "please provide the name of the Base model"
assert script_args.output_name is not None, "please provide the output name of the merged model"

peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name)
if peft_config.task_type == "SEQ_CLS":
# The sequence classification task is used for the reward model in PPO
model = AutoModelForSequenceClassification.from_pretrained(
script_args.base_model_name, num_labels=1, torch_dtype=torch.bfloat16
)
else:
model = AutoModelForCausalLM.from_pretrained(
script_args.base_model_name, return_dict=True, torch_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name)

# Load the PEFT model
model = PeftModel.from_pretrained(model, script_args.adapter_model_name)
model.eval()

model = model.merge_and_unload()

model.save_pretrained(f"{script_args.output_name}")
tokenizer.save_pretrained(f"{script_args.output_name}")
# model.push_to_hub(f"{script_args.output_name}", use_temp_dir=False)
5 changes: 5 additions & 0 deletions examples/trl/stack_llama_2/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
trl == 0.7.6
peft == 0.6.2
datasets
wandb
tyro
Loading