# SPIN: Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models

This repository hosts a `verl` recipe inspired by the paper **"Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models"** (SPIN). SPIN is a language model finetuning algorithm that enables iterative self-improvement through a self-play mechanism inspired by game theory.

**Core Idea:** Models learn by playing against themselves, reducing reliance on external preference datasets or stronger teacher models:

1. **Synthetic Data Generation:** The current model generates responses, creating its own training data from its previous iteration's outputs.
2. **Two-Player Game Setup:** A two-player game in which a single LLM plays both roles: the opponent (a previous iteration of the model) generates responses, and the main player learns to distinguish them from the target data distribution.
3. **Iterative Training:** The model progressively improves by refining its policy, with each iteration's model becoming the opponent for the next iteration.

Paper Authors: [Zixiang Chen](https://github.com/uclaml/SPIN)\*, [Yihe Deng](https://github.com/uclaml/SPIN)\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)

[[Webpage](https://uclaml.github.io/SPIN/)] [[Huggingface](https://huggingface.co/papers/2401.01335)] [[Paper](https://arxiv.org/abs/2401.01335)] [[Original Implementation](https://github.com/uclaml/SPIN)]

verl Implementation Authors: [Chendong Wang](https://cdwang96.github.io/), [Chenyang Zhao](https://github.com/zhaochenyang20)

---

## Key Function (`compute_online_dpo_loss`) and Related Works
SPIN (Chen et al., 2024) proposes an iterative self-play mechanism to fine-tune language models. In each iteration, SPIN's training objective, when using a logistic loss function, is equivalent to Direct Preference Optimization (DPO) loss (Rafailov et al., 2023).
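For reference, the DPO objective in question (Rafailov et al., 2023) can be written as:

$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}}) = -\,\mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}} \left[ \log \sigma\!\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)} \right) \right]
$$

where $y_w$ and $y_l$ are the chosen and rejected responses, $\sigma$ is the logistic function, and $\beta$ controls the strength of the implicit KL regularization toward the reference policy $\pi_{\mathrm{ref}}$.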

This `verl` recipe realizes SPIN's core concept by using DPO loss iteratively (Xu et al., 2023; Xiong et al., 2023; Snorkel AI, 2024). This means that in each iteration, we fine-tune the LLM using DPO loss for preference optimization. Notably, Xu et al. (2023) explored iterative preference optimization with pairwise cringe loss, while Xiong et al. (2023) discussed how to bridge theory and practice for RLHF under KL constraints using iterative training. The concept of iterative preference learning was also explored in online DPO (Guo et al., 2024), which focuses on direct alignment from online AI feedback. In online DPO, preference data is dynamically updated during training, allowing the model to learn from its own generated data.

Specifically, we developed the **`compute_online_dpo_loss`** function and built this SPIN recipe on top of it. By incorporating online preference generation, this approach enables continuous refinement of language models without relying on fixed external preference datasets.
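As a rough illustration of the loss involved, here is a minimal pure-Python sketch of a sigmoid-variant DPO loss for a single preference pair. The function name and scalar interface are simplifications for exposition, not the actual `compute_online_dpo_loss` signature in `core_algos.py`:

```python
import math

def compute_online_dpo_loss_sketch(
    logp_chosen: float,       # log-prob of the chosen response under the policy
    logp_rejected: float,     # log-prob of the rejected response under the policy
    ref_logp_chosen: float,   # same quantities under the reference policy
    ref_logp_rejected: float,
    beta: float = 0.1,
) -> float:
    """Sigmoid-variant DPO loss for one (chosen, rejected) pair."""
    # Implicit reward margin: difference of policy-vs-reference log-ratios.
    margin = (logp_chosen - ref_logp_chosen) - (logp_rejected - ref_logp_rejected)
    # Loss is -log(sigmoid(beta * margin)): small when the policy prefers
    # the chosen response more strongly than the reference does.
    return -math.log(1.0 / (1.0 + math.exp(-beta * margin)))
```

With a zero margin the loss is log 2; it decreases as the policy's relative preference for the chosen response grows.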

**Reference Papers:**
* [Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models](https://arxiv.org/abs/2401.01335) (Chen et al., 2024)
* [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) (Rafailov et al., 2023)
* [Some Things Are More Cringe than Others: Preference Optimization with the Pairwise Cringe Loss](https://arxiv.org/abs/2312.16682) (Xu et al., 2023)
* [Iterative Preference Learning from Human Feedback: Bridging Theory and Practice for RLHF under KL-Constraint](https://arxiv.org/abs/2312.11456) (Xiong et al., 2023)
* [Snorkel-Mistral-PairRM-DPO](https://huggingface.co/snorkelai/Snorkel-Mistral-PairRM-DPO) (Snorkel AI, 2024)
* [Direct Language Model Alignment from Online AI Feedback](https://arxiv.org/abs/2402.04792) (Guo et al., 2024)


## Our Online DPO Implementation

Our `compute_online_dpo_loss` function adapts `verl`'s existing PPO infrastructure (based on `verl` v0.3.0.post1) to iterative Online DPO training. Key aspects of our implementation include:

* **No Critic:** Unlike PPO, we omit the value function critic.
* **Dynamic Reference Model:** An explicit reference policy (`ref_policy_wg`) is used for DPO loss. This reference model's weights can be periodically updated from the actor (`ref_update_freq`), providing a dynamic baseline.
* **Online Preference Generation:** The `compute_onlineDPO_pref` function (in `core_algos.py`) dynamically creates chosen/rejected pairs based on a reward source (e.g., rule-based ranking for math problems).
* **DPO Loss Integration:** We replace PPO's policy loss with our `compute_online_dpo_loss` (in `core_algos.py`) within the actor update (`dp_actor.py`), directly optimizing the policy using the generated preferences.
* **Iterative Training Orchestration:** The `SpinTrainer` (in `spin_trainer.py`) manages the entire self-play loop: generation, preference labeling, optional reference model updates, and policy updates, enabling continuous self-improvement aligned with SPIN's principles.

---
## Algorithm

This recipe implements an Online DPO algorithm adapted to the `verl` reinforcement learning framework, providing an alternative to PPO for fine-tuning language models.

**Online Loop:** Instead of maximizing a scalar reward signal as in PPO, this approach directly optimizes the policy model to align with preference data generated *online* during training:

1. **Generation:** The current model generates multiple responses for each prompt in a batch.
2. **Preference Labeling:** A function evaluates the generated responses to determine which is preferred (chosen) and which is dispreferred (rejected), using either a reward function or implicit rule-based ranking (this recipe uses rule-based ranking on math problems).
3. **Update:** This preference tuple (`prompt`, `chosen_response`, `rejected_response`) is used to update the actor model using `compute_online_dpo_loss`, comparing against a reference model.
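The preference-labeling step (2) can be illustrated with a toy rule-based checker. The substring rule below is a made-up stand-in for the recipe's actual math-answer ranking in `compute_onlineDPO_pref`:

```python
def rule_based_preference(prompt: str, responses: list, ground_truth: str):
    """Emit a (prompt, chosen, rejected) preference tuple via a toy rule."""
    def score(resp: str) -> float:
        # Toy rule: reward 1.0 if the ground-truth answer string appears.
        return 1.0 if ground_truth in resp else 0.0

    ranked = sorted(responses, key=score, reverse=True)
    chosen, rejected = ranked[0], ranked[-1]
    # This tuple is what the DPO update step consumes.
    return prompt, chosen, rejected
```

The resulting tuple feeds directly into step 3's DPO update against the reference model.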

**Connection with SPIN:**
Rather than training against a fixed target data distribution, the online generation loop (step 2) dynamically shifts the target distribution through the chosen preference-labeling method (here, rule-based ranking that keeps the better response on math problems). This explores the direction discussed in Section 7 of the SPIN paper, "dynamically changing target data distribution," which may lift LLM performance beyond the ceiling of fixed human-annotated data.

---

## Reproduce the Experiment (Example Setup)

The following steps outline how to set up the environment and run the SPIN recipe.
1. **Setup Environment (Example using Docker):**
```bash
# Start a container with GPU access and shared memory
docker run -it --name spin_test --gpus all \
--shm-size=32g \
--ipc=host \
-v /path/to/host/.cache:/root/.cache \
-e HF_TOKEN=<YOUR_HUGGINGFACE_TOKEN> \
lmsysorg/sglang:latest \
/bin/bash

# Inside the container or on your host machine:
# Ensure /tmp is writable
mkdir -p /tmp
chmod 1777 /tmp

# Install Python 3.10 (if not present) and venv
sudo apt update
sudo apt install -y python3.10 python3.10-venv tmux
python3 -m ensurepip --upgrade

# Create and activate a virtual environment
python3 -m venv ~/.python/spin_env
```
```bash
# Clone the verl repository and checkout the spin branch
cd ~
git clone git@github.com:volcengine/verl.git && cd verl

# Install flash-attn (handle potential build issues)
python3 -m uv pip install wheel packaging
```

```bash
bash recipe/spin/run_spin.sh
```

---

## Configuration

* The primary configuration is typically managed through a YAML file specified in the launch script (e.g., `config/spin_trainer.yaml`).
* `algorithm`: DPO-specific hyperparameters like `dpo_beta`, `dpo_loss_type`.
* `trainer`: Distributed training settings (nodes, GPUs per node), logging (WandB), checkpointing frequency, and `ref_update_freq` (set > 0 to enable periodic reference model updates from the actor).

---

## Key Files

* `main_spin.py`: Main entry point using Hydra to load the config and launch the `SpinTrainer`.
* `spin_trainer.py`: Defines the `SpinTrainer` class, orchestrating the Online DPO training loop.
* `fsdp_workers.py`: Implements Ray workers (Actor, Reference) potentially using FSDP.
* `dp_actor.py`: Contains the actor class, including the DPO policy update logic.
* `core_algos.py`: Includes helper functions for `compute_online_dpo_loss` and `compute_onlineDPO_pref`.
* `config/spin_trainer.yaml` (or similar): Main Hydra configuration file for the recipe.
* `run_spin.sh` (or similar): Example bash script for launching a training run.
* `README.md`: This file.

---

## Acknowledgement

We sincerely thank the `verl` community and our advisors for their contributions and guidance, including (list adapted from SPPO):

* [Zixiang Chen](https://sites.google.com/view/zxchen)
* [Yuhao Yang](https://github.com/yhyang201)
* [Yifan Zhang](https://github.com/yifanzhang-pro)
* [Yongan Xiang](https://github.com/BearBiscuit05)
* [Junrong Lin](https://github.com/ocss884)
* [Yuxuan Tong](https://github.com/tongyx361)
* [Guangming Shen](https://github.com/PeterSH6)
* [Biao He](https://www.linkedin.com/in/biao-he/)
* [Qingquan Song](https://qingquansong.github.io/)
* [Chenyang Zhao](https://zhaochenyang20.github.io/Chayenne/)
* [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)

---