Direct Preference Optimization

Direct Preference Optimization (DPO) in RWKV-LM-RLHF

Overview

Direct Preference Optimization (DPO) is an advanced training method that optimizes language models based on human preferences. Unlike traditional reinforcement learning approaches, DPO directly learns from paired examples of preferred and non-preferred responses.

Technical Details

DPO works by:

  • Optimizing a model to maximize the likelihood of preferred responses while minimizing the likelihood of non-preferred ones
  • Using implicit reward modeling through preference pairs (a loss sketch follows this list)
  • Implementing a more stable training process compared to traditional RLHF methods
  • Avoiding the need for separate reward model training

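To make the preference-pair idea concrete, the following is a minimal PyTorch sketch of the standard DPO loss. It illustrates the technique only; the tensor names and the way the per-sequence log-probabilities are obtained are assumptions, not the exact code in this repository.

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_reject_logps,
             ref_chosen_logps, ref_reject_logps, beta=0.01):
    # Each input is a (batch,) tensor with the summed log-probability of the
    # chosen/rejected completion under the trained policy or the frozen reference model.
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)   # implicit reward, chosen
    reject_rewards = beta * (policy_reject_logps - ref_reject_logps)   # implicit reward, rejected
    margin = chosen_rewards - reject_rewards
    loss = -F.logsigmoid(margin).mean()            # push chosen above rejected
    pref_percentage = (margin > 0).float().mean()  # share of pairs the policy already prefers
    return loss, pref_percentage

The pref_percentage value here corresponds to the Pref-Percentage metric discussed under Monitoring Training Progress below.
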
Simplified Explanation

Think of DPO like teaching a student by showing them two different answers to the same question - one good and one not so good. Instead of giving specific grades, we just show which answer is better. The student learns to give answers more like the good ones and less like the not-so-good ones.

System Requirements

Hardware Requirements

  • Recommended GPU: 24GB VRAM (RTX3090, RTX4090, AMD MI100)
    • DPO requires approximately 2x the compute power and VRAM compared to standard SFT
    • Note: 16GB GPUs might work, but this is untested

Software Requirements

  • Operating System: Ubuntu 22.04 or 24.04
    • Recommendation: Disable Wayland for training stability

Implementation Guide

1. Model Preparation

RWKV-LM-RLHF supports:

  • RWKV v6 (Finch)
  • RWKV v7 (Goose)

Note: Model checkpoints are not currently cross-compatible between the two versions

How to try

Reference Model

For this guide, we use the RWKV v6 1.6B model: Download Link

2. Dataset Preparation

Format Requirements

The dataset should be structured as a CSV file containing three columns:

  • Prompt
  • Chosen (preferred response)
  • Reject (non-preferred response)

Example CSV Format

prompt,chosen,reject
who are you?,i'm RWKV whats up?,i'm an AI Assistant. how can i help you?
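
A quick way to confirm that a dataset file follows this structure before preprocessing (a small pandas sketch; the path shown is the sample file used in the commands below):

import pandas as pd

df = pd.read_csv('example/DPO/input_csv/rlhf_example_dataset.csv')
missing = {'prompt', 'chosen', 'reject'} - set(df.columns)
assert not missing, f'CSV is missing columns: {missing}'
print(f'{len(df)} pairs loaded, {int(df["reject"].isna().sum())} rows with an empty reject column')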

Dataset Examples

Sample Datasets

RWKV-LM-RLHF provides two types of sample datasets:

  1. Complete Dataset (with reject responses)
    Contains Prompt, Chosen, and Reject responses
    View Sample

  2. Base Dataset (without reject responses)
    Contains only Prompt and Chosen responses
    View Sample

Generating Reject Responses

The repository includes a utility to automatically generate reject responses using the base model. This is useful when you only have preferred responses available.

Command Structure

python rlhf_generate_reject_csv.py --load_model 'myfolder/models/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth' \
 --input_csv 'example/DPO/input_csv/rlhf_example_dataset.csv' \
 --output_csv 'example/DPO/output_csv/rlhf_example_dataset_withreject.csv' \
 --strategy 'cuda fp16' 

Parameters Explanation

  • --load_model: Path to the model used for generating reject responses
  • --input_csv: Path to input CSV file
    • Must contain columns: prompt, chosen, reject (reject can be empty)
  • --output_csv: Path where the generated dataset will be saved
  • --strategy: Model inference strategy
    • Default: 'cuda fp16'
    • For larger models (e.g., 14B): Use 'cuda fp16i8' to fit within VRAM constraints

Note: Ensure your input CSV maintains the required column structure (prompt, chosen, reject) even if the reject column is empty.
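
If your data has only prompt and chosen columns, a short pandas sketch like the one below can add the empty reject column before running the script above. The file names are placeholders, not files shipped with the repository.

import pandas as pd

df = pd.read_csv('my_prompt_chosen_only.csv')   # placeholder: your prompt/chosen-only CSV
if 'reject' not in df.columns:
    df['reject'] = ''                           # empty reject column; the utility above fills it in
df.to_csv('example/DPO/input_csv/my_dataset.csv', index=False)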

3. Data Preprocessing

First, we need to tokenize the dataset and generate reference logits:

python rlhf_dpo_generate_save.py  --load_model 'myfolder/models/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth' \
 --input_csv 'example/DPO/output_csv/rlhf_example_dataset_withreject.csv' \
 --output_save 'example/DPO/output_save/rlhf_example_dataset.save' \
 --target_pair_count 60

Parameters Explanation

  • --load_model: Path to the base model
  • --input_csv: Path to CSV file containing prompt, chosen, and reject columns
  • --output_save: Path for saving processed dataset (used for training)
  • --target_pair_count: Number of pairs to process (recommended: 2x the number of pairs in the CSV)
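
Conceptually, the "reference logits" stored in the .save file are the frozen base model's log-probabilities of each chosen and rejected completion, which the DPO loss later compares against the trained policy. The sketch below shows only that computation; the model interface is an assumption, and the actual script additionally tokenizes and serializes the pairs.

import torch

@torch.no_grad()
def reference_logprob(model, input_ids, completion_start):
    # Sum of log P(token | prefix) over the completion part of a tokenized
    # prompt+completion sequence. Assumes model(batch) returns (1, T, vocab) logits.
    logits = model(input_ids.unsqueeze(0))
    logprobs = torch.log_softmax(logits, dim=-1)
    token_lp = logprobs[0, :-1].gather(1, input_ids[1:].unsqueeze(-1)).squeeze(-1)
    return token_lp[completion_start - 1:].sum()   # keep only the completion tokens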

4. Training Configuration

Launch training with the following command:

python train.py --load_model 'myfolder/models/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth' \
 --wandb "RWKV-LM-RLHF 1B6-RLHF DPO" --proj_dir "myfolder/Outputs/1B6-RLHF-DPO" \
 --infctx 0 \
 --vocab_size 65536 --ctx_len 2048 \
 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
 --micro_bsz 1 --n_layer 24 --n_embd 2048 \
 --lr_init 5e-6 --lr_final 1e-6 \
 --warmup_steps 100 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
 --accelerator gpu --devices 1 --precision bf16 \
 --grad_cp 1 --my_testing "x060" \
 --strategy deepspeed_stage_2_offload \
 --layer_profile 'layerprofile/24_TEST.csv' \
 --quant 1 \
 --quant_mode 'nf4' \
 --gpu_arch 'cuda' \
 --dpo 1 \
 --dpo_alpha 0.1 \
 --dpo_beta 0.01 \
 --rlhf_train_file 'example/DPO/output_save/rlhf_example_dataset.save' \
 --rlhf_max_corpus_len 1024 \
 --accumulate_grad_batches 16

Key Training Parameters

Essential Parameters

  • --load_model: Path to base model
  • --ctx_len: CUDA kernel context length (set to 2x rlhf_max_corpus_len)
  • --lr_init: Initial learning rate (should be very low for RLHF)
  • --lr_final: Final learning rate (typically 1/5 of initial rate)
  • --dpo_alpha: SFT loss ratio (0-1.0, recommended: 0.1-0.5); see the sketch after this list
  • --dpo_beta: DPO loss constant (0-1.0, recommended: 0.01-0.1)
  • --layer_profile: Detailed training strategy configuration; see the LayerProfile page
  • --rlhf_max_corpus_len: Maximum corpus length per pair (applies to both prompt + chosen and prompt + reject)
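
As noted above, dpo_alpha mixes a standard SFT (cross-entropy) term on the chosen responses into the training objective, while dpo_beta scales the implicit reward inside the DPO loss itself. A plausible reading of how the two interact is sketched below; the exact weighting used in train.py may differ.

def total_loss(dpo_loss, sft_loss_on_chosen, dpo_alpha=0.1):
    # dpo_alpha blends in ordinary next-token loss on the chosen responses,
    # which helps keep generations fluent while the preference margin grows.
    # dpo_beta acts earlier, inside the DPO loss (see the sketch in Technical Details).
    return dpo_loss + dpo_alpha * sft_loss_on_chosen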

Monitoring Training Progress

  1. Setup: Log in to Weights & Biases (Wandb) for monitoring
  2. Key Metric: Watch the Pref-Percentage
    • Should start around 0.5
    • Gradually increase towards 1.0
    • Steady increase indicates successful training

Training Success Indicators

A successful DPO training typically shows:

  • Preference percentage starting at ~0.5
  • Gradual, consistent increase
  • Movement toward 1.0
  • No sudden jumps or unstable behavior

Note: The layer profile configuration details can be found in the referenced documentation.