Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
42e1170
Add WinRateCallback
lewtun Apr 29, 2024
52ae31e
Enable PairRM
lewtun Apr 29, 2024
2065c3f
Refactor
lewtun Apr 29, 2024
efa2e21
Streamline
lewtun Apr 29, 2024
27be35d
Add HF judge
lewtun Apr 30, 2024
ce4dad2
Add base judge
lewtun May 2, 2024
29c5f98
Use better prompt
lewtun May 2, 2024
b6b84ab
Merge branch 'main' into add-winrate-cb
lewtun May 3, 2024
ea60ebb
Clean
lewtun May 3, 2024
9542e73
Add max tokens
lewtun May 3, 2024
568e7b3
Use logging
lewtun May 3, 2024
9362eff
Add batched inference
lewtun May 3, 2024
e05e0ca
Squashed commit of the following:
qgallouedec Jul 15, 2024
f49a043
Merge branch 'main' into add-winrate-cb
qgallouedec Jul 15, 2024
6ae3993
judge refactoring and unittest
qgallouedec Jul 15, 2024
a2bfe57
format
qgallouedec Jul 15, 2024
3e8afdd
init
qgallouedec Jul 15, 2024
faaf38f
doc
qgallouedec Jul 15, 2024
ee0e4e9
Merge branch 'main' into add-winrate-cb
qgallouedec Jul 15, 2024
b1fb345
format
qgallouedec Jul 15, 2024
47624c3
improve doc
qgallouedec Jul 15, 2024
c89fdc7
basejudge
qgallouedec Jul 15, 2024
38a527f
improve doc and add BaseAPIJudge
qgallouedec Jul 17, 2024
721cb75
Doc
qgallouedec Jul 17, 2024
915147e
style
qgallouedec Jul 17, 2024
984d339
refactor callback
qgallouedec Jul 17, 2024
59b6ef3
remove openai and pairrm judge from test
qgallouedec Jul 17, 2024
a039f23
doc
qgallouedec Jul 18, 2024
7ddbc28
rm dpo online example
qgallouedec Jul 18, 2024
b68f2ab
new prompts and completions
qgallouedec Jul 18, 2024
0317393
skip hf judge and add hf token
qgallouedec Jul 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@
title: ORPO Trainer
- local: iterative_sft_trainer
title: Iterative Supervised Fine-Tuning
- local: callbacks
title: Callback Classes
- local: judges
title: Judge Classes
- local: text_environments
title: Text Environments
title: API
Expand Down
13 changes: 13 additions & 0 deletions docs/source/callbacks.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Callbacks

## SyncRefModelCallback

[[autodoc]] SyncRefModelCallback

## RichProgressCallback

[[autodoc]] RichProgressCallback

## WinRateCallback

[[autodoc]] WinRateCallback
62 changes: 62 additions & 0 deletions docs/source/judges.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Judges

TRL provides judges to easily compare two completions.

Make sure to have installed the required dependencies by running:

```bash
pip install trl[llm_judge]
```

## Define your own judge

To define your own judge, you need to subclass [`BaseJudge`] and implement the [`BaseJudge.judge`] method, which returns a list of integers — one per pair — where each value is the index (0 or 1) of the better completion in that pair. Here is a dummy example where we define a simple judge that favors longer completions:

```python
from trl import BaseJudge

class LengthBasedJudge(BaseJudge):
    """Toy judge that always prefers the longer of the two completions."""

    def judge(self, prompts, completion_pairs, shuffle_order=False):
        # Return the index (0 or 1) of the longer completion; ties go to index 1.
        return [int(len(first) <= len(second)) for first, second in completion_pairs]
```

You can then use this judge as follows:

```python
judge = LengthBasedJudge()
judge.judge(
prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"],
completion_pairs=[["Paris", "The capital of France is Paris."], ["Jupiter is the biggest planet in the solar system.", "Jupiter"]],
) # Outputs: [1, 0]
```

TRL also provides a [`BaseAPIJudge`] class that can be used to define judges that interact with an API. You can subclass [`BaseAPIJudge`] and implement the [`BaseAPIJudge.get_response`] method that should return the response from the API. For an example, see the [`HuggingFaceJudge`] class.


## BaseJudge

[[autodoc]] BaseJudge

## BaseAPIJudge

[[autodoc]] BaseAPIJudge

## HuggingFaceJudge

[[autodoc]] HuggingFaceJudge

## MockAPIJudge

[[autodoc]] MockAPIJudge

## MockJudge

[[autodoc]] MockJudge

## OpenAIJudge

[[autodoc]] OpenAIJudge

## PairRMJudge

[[autodoc]] PairRMJudge
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
"deepspeed": ["deepspeed>=0.9.5"],
"benchmark": ["wandb", "ghapi", "openrlbenchmark==0.2.1a5", "requests", "deepspeed"],
"quantization": ["bitsandbytes<=0.41.1"],
"llm_judge": ["openai>=1.23.2", "huggingface_hub>=0.22.2", "llm-blender>=0.0.2"],
}
EXTRAS["dev"] = []
for reqs in EXTRAS.values():
Expand Down
33 changes: 33 additions & 0 deletions tests/test_judges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import unittest

from trl import HuggingFaceJudge, MockAPIJudge, MockJudge


class TestJudges(unittest.TestCase):
    """Sanity checks for the judge implementations shipped with TRL."""

    def _get_prompts_and_completion_pairs(self):
        # Two toy prompts, each paired with two candidate completions to be ranked.
        prompts = ["The capital of France is", "The biggest planet in the solar system is"]
        completion_pairs = [["Paris", "Marseille"], ["Saturn", "Jupiter"]]
        return prompts, completion_pairs

    def _assert_integer_ranks(self, ranks):
        # Every judge must emit exactly one integer verdict per prompt.
        self.assertEqual(len(ranks), 2)
        self.assertTrue(all(isinstance(rank, int) for rank in ranks))

    def test_mock_judge(self):
        prompts, completion_pairs = self._get_prompts_and_completion_pairs()
        ranks = MockJudge().judge(prompts=prompts, completion_pairs=completion_pairs)
        self._assert_integer_ranks(ranks)

    def test_mock_api_judge(self):
        prompts, completion_pairs = self._get_prompts_and_completion_pairs()
        ranks = MockAPIJudge().judge(prompts=prompts, completion_pairs=completion_pairs)
        self._assert_integer_ranks(ranks)

    @unittest.skip("This test needs to be run manually since it requires a valid Hugging Face API key.")
    def test_hugging_face_judge(self):
        prompts, completion_pairs = self._get_prompts_and_completion_pairs()
        ranks = HuggingFaceJudge().judge(prompts=prompts, completion_pairs=completion_pairs)
        self._assert_integer_ranks(ranks)
        self.assertEqual(ranks, [0, 1])
20 changes: 20 additions & 0 deletions trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
"is_pil_available",
"is_wandb_available",
"is_xpu_available",
"is_llmblender_available",
"is_openai_available",
],
"models": [
"AutoModelForCausalLMWithValueHead",
Expand Down Expand Up @@ -55,6 +57,14 @@
"SFTTrainer",
"FDivergenceConstants",
"FDivergenceType",
"WinRateCallback",
"BaseJudge",
"BaseAPIJudge",
"HuggingFaceJudge",
"MockAPIJudge",
"MockJudge",
"OpenAIJudge",
"PairRMJudge",
],
"commands": [],
"commands.cli_utils": ["init_zero_verbose", "SFTScriptArguments", "DPOScriptArguments", "TrlParser"],
Expand Down Expand Up @@ -95,6 +105,8 @@
is_pil_available,
is_wandb_available,
is_xpu_available,
is_llmblender_available,
is_openai_available,
)
from .models import (
AutoModelForCausalLMWithValueHead,
Expand Down Expand Up @@ -126,6 +138,14 @@
SFTTrainer,
FDivergenceConstants,
FDivergenceType,
WinRateCallback,
BaseJudge,
BaseAPIJudge,
HuggingFaceJudge,
MockAPIJudge,
MockJudge,
OpenAIJudge,
PairRMJudge,
)
from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
Expand Down
8 changes: 8 additions & 0 deletions trl/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ def is_sklearn_available() -> bool:
return find_spec("sklearn") is not None


def is_llmblender_available() -> bool:
    """Return whether the `llm-blender` package (imported as `llm_blender`) is installed."""
    # `find_spec` yields a ModuleSpec when importable and None otherwise; bool() maps that to True/False.
    return bool(find_spec("llm_blender"))


def is_openai_available() -> bool:
    """Return whether the `openai` client library is installed."""
    # `find_spec` yields a ModuleSpec when importable and None otherwise; bool() maps that to True/False.
    return bool(find_spec("openai"))


def is_xpu_available() -> bool:
if is_accelerate_greater_20_0():
import accelerate
Expand Down
12 changes: 12 additions & 0 deletions trl/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@
"sft_trainer": ["SFTTrainer"],
"base": ["BaseTrainer"],
"ddpo_config": ["DDPOConfig"],
"callbacks": ["RichProgressCallback", "SyncRefModelCallback", "WinRateCallback"],
"judges": [
"BaseJudge",
"BaseAPIJudge",
"HuggingFaceJudge",
"MockAPIJudge",
"MockJudge",
"OpenAIJudge",
"PairRMJudge",
],
}

try:
Expand Down Expand Up @@ -95,6 +105,8 @@
from .reward_trainer import RewardTrainer, compute_accuracy
from .sft_config import SFTConfig
from .sft_trainer import SFTTrainer
from .callbacks import RichProgressCallback, SyncRefModelCallback, WinRateCallback
from .judges import BaseJudge, BaseAPIJudge, HuggingFaceJudge, MockAPIJudge, MockJudge, OpenAIJudge, PairRMJudge

try:
if not is_diffusers_available():
Expand Down
103 changes: 99 additions & 4 deletions trl/trainer/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
from typing import List, Optional, Union

import torch
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import is_deepspeed_available
from accelerate.utils import gather_object, is_deepspeed_available
from rich.console import Console, Group
from rich.live import Live
from rich.panel import Panel
from rich.progress import Progress
from transformers import PreTrainedModel
from transformers.trainer import TrainerCallback
from transformers import (
GenerationConfig,
PreTrainedModel,
Trainer,
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from transformers.trainer_utils import has_length

from ..models.utils import unwrap_model_for_generation
from .judges import BaseJudge


if is_deepspeed_available():
import deepspeed
Expand Down Expand Up @@ -138,3 +148,88 @@ def on_train_end(self, args, state, control, **kwargs):
self.rich_console = None
self.training_status = None
self.current_step = None


class WinRateCallback(TrainerCallback):
    """
    A [`~transformers.TrainerCallback`] that computes the win rate of a model based on a reference.

    The reference model's completions for the evaluation prompts are generated once at the start of
    training; at each evaluation the current model's completions are generated and the judge decides
    which of the two is better. The fraction of comparisons won by the current model is logged as
    `eval_win_rate`.

    Usage:
    ```python
    trainer = DPOTrainer(...)
    win_rate_callback = WinRateCallback(..., trainer=trainer)
    trainer.add_callback(win_rate_callback)
    ```

    Args:
        prompts (`List[str]`):
            The prompts to generate completions for.
        judge (`BaseJudge`):
            The judge to use for comparing completions.
        trainer (`Trainer`):
            The trainer. Must expose a `ref_model` attribute and an `eval_dataset` with a
            `"prompt"` column.
        generation_config (`GenerationConfig`, *optional*):
            The generation config to use for generating completions.
        batch_size (`int`, *optional*):
            The batch size to use for generating completions. Defaults to 4.

    Raises:
        AttributeError: If `trainer` has no `ref_model` attribute.
    """

    def __init__(
        self,
        prompts: List[str],
        judge: BaseJudge,
        trainer: Trainer,
        generation_config: Optional[GenerationConfig] = None,
        batch_size: int = 4,
    ):
        self.prompts = prompts
        self.generation_config = generation_config
        self.judge = judge
        # Filled in `on_train_begin` with the reference model's completions for this
        # process's shard of the evaluation prompts.
        self.ref_completions = []
        self.trainer = trainer
        self.eval_dataset = self.trainer.eval_dataset
        if not hasattr(trainer, "ref_model"):
            raise AttributeError("Trainer must have a `ref_model` attribute.")
        self.batch_size = batch_size

    def generate_completions_for_model(self, model, tokenizer, prompts):
        """Generate one completion string per prompt with `model`, in batches of `self.batch_size`."""
        completions = []
        with unwrap_model_for_generation(model, self.trainer.accelerator) as unwrapped_model:
            unwrapped_model.eval()
            for idx in range(0, len(prompts), self.batch_size):
                batch = prompts[idx : idx + self.batch_size]
                tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(model.device)
                generations = unwrapped_model.generate(
                    **tokenized_batch,
                    generation_config=self.generation_config,
                )
                for prompt, generation in zip(tokenized_batch.input_ids, generations):
                    # Strip the (left-padded) prompt tokens so only the completion is decoded.
                    generation = generation[len(prompt) :]
                    completion = tokenizer.decode(generation, skip_special_tokens=True)
                    completions.append(completion)

            unwrapped_model.train()
        return completions

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        tokenizer = kwargs["tokenizer"]
        # Left padding keeps prompt tokens adjacent to generated tokens during `generate`.
        tokenizer.padding_side = "left"
        accelerator = self.trainer.accelerator
        # NOTE(review): `apply_padding=True` may duplicate trailing prompts so every process gets
        # an equal shard; those duplicates are also gathered and slightly weight the win rate.
        with accelerator.split_between_processes(self.eval_dataset["prompt"], apply_padding=True) as prompts:
            self.ref_completions = self.generate_completions_for_model(self.trainer.ref_model, tokenizer, prompts)

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        model = kwargs["model"]
        tokenizer = kwargs["tokenizer"]
        accelerator = self.trainer.accelerator
        with accelerator.split_between_processes(self.eval_dataset["prompt"], apply_padding=True) as prompts:
            completions = self.generate_completions_for_model(model, tokenizer, prompts)
            completion_pairs = list(zip(self.ref_completions, completions))
            # Judge the LOCAL shard of prompts, not the full dataset: `completion_pairs` only
            # covers this process's prompts, so the two lists must be aligned one-to-one.
            # (Passing `self.eval_dataset["prompt"]` here would misalign prompts and pairs
            # whenever the dataset is split across processes.)
            winner_indices = self.judge.judge(prompts, completion_pairs)
            winner_indices = gather_object(winner_indices)

        # Logging
        if self.trainer.accelerator.is_main_process:
            # Index 1 marks the current model's completion beating the reference's.
            win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices)
            self.trainer.log({"eval_win_rate": win_rate})
Loading