diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index e4a1041033f..4287624c2ac 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -47,6 +47,10 @@ title: ORPO Trainer - local: iterative_sft_trainer title: Iterative Supervised Fine-Tuning + - local: callbacks + title: Callback Classes + - local: judges + title: Judge Classes - local: text_environments title: Text Environments title: API diff --git a/docs/source/callbacks.mdx b/docs/source/callbacks.mdx new file mode 100644 index 00000000000..e4d26797c29 --- /dev/null +++ b/docs/source/callbacks.mdx @@ -0,0 +1,13 @@ +# Callbacks + +## SyncRefModelCallback + +[[autodoc]] SyncRefModelCallback + +## RichProgressCallback + +[[autodoc]] RichProgressCallback + +## WinRateCallback + +[[autodoc]] WinRateCallback diff --git a/docs/source/judges.mdx b/docs/source/judges.mdx new file mode 100644 index 00000000000..23143f9db6d --- /dev/null +++ b/docs/source/judges.mdx @@ -0,0 +1,62 @@ +# Judges + +TRL provides judges to easily compare two completions. + +Make sure to have installed the required dependencies by running: + +```bash +pip install trl[llm_judge] +``` + +## Define your own judge + +To define your own judge, you need to subclass [`BaseJudge`] and implement the [`BaseJudge.judge`] method that returns a list of 0/1 indicating which completion is better. Here is a dummy example where we define a simple judge that favors longer completions: + +```python +from trl import BaseJudge + +class LengthBasedJudge(BaseJudge): + def judge(self, prompts, completion_pairs, shuffle_order=False): + return [0 if len(c1) > len(c2) else 1 for c1, c2 in completion_pairs] +``` + +You can then use this judge as follows: + +```python +judge = LengthBasedJudge() +judge.judge( + prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"], + completion_pairs=[["Paris", "The capital of France is Paris."], ["Jupiter is the biggest planet in the solar system.", "Jupiter"]], +) # Outputs: [1, 0] +``` + +TRL also provides a [`BaseAPIJudge`] class that can be used to define judges that interact with an API. You can subclass [`BaseAPIJudge`] and implement the [`BaseAPIJudge.get_response`] method that should return the response from the API. For an example, see the [`HuggingFaceJudge`] class. + + +## BaseJudge + +[[autodoc]] BaseJudge + +## BaseAPIJudge + +[[autodoc]] BaseAPIJudge + +## HuggingFaceJudge + +[[autodoc]] HuggingFaceJudge + +## MockAPIJudge + +[[autodoc]] MockAPIJudge + +## MockJudge + +[[autodoc]] MockJudge + +## OpenAIJudge + +[[autodoc]] OpenAIJudge + +## PairRMJudge + +[[autodoc]] PairRMJudge diff --git a/setup.py b/setup.py index cdbce4093bd..373cfeea064 100644 --- a/setup.py +++ b/setup.py @@ -84,6 +84,7 @@ "deepspeed": ["deepspeed>=0.9.5"], "benchmark": ["wandb", "ghapi", "openrlbenchmark==0.2.1a5", "requests", "deepspeed"], "quantization": ["bitsandbytes<=0.41.1"], + "llm_judge": ["openai>=1.23.2", "huggingface_hub>=0.22.2", "llm-blender>=0.0.2"], } EXTRAS["dev"] = [] for reqs in EXTRAS.values(): diff --git a/tests/test_judges.py b/tests/test_judges.py new file mode 100644 index 00000000000..d3920c0b894 --- /dev/null +++ b/tests/test_judges.py @@ -0,0 +1,33 @@ +import unittest + +from trl import HuggingFaceJudge, MockAPIJudge, MockJudge + + +class TestJudges(unittest.TestCase): + def _get_prompts_and_completion_pairs(self): + prompts = ["The capital of France is", "The biggest planet in the solar system is"] + completion_pairs = [["Paris", "Marseille"], ["Saturn", "Jupiter"]] + return prompts, completion_pairs + + def test_mock_judge(self): + judge = MockJudge() + prompts, completion_pairs = self._get_prompts_and_completion_pairs() + ranks = judge.judge(prompts=prompts, completion_pairs=completion_pairs) + self.assertEqual(len(ranks), 2) + self.assertTrue(all(isinstance(rank, int) for rank in ranks)) + + def test_mock_api_judge(self): + judge = MockAPIJudge() + prompts, completion_pairs = self._get_prompts_and_completion_pairs() + ranks = judge.judge(prompts=prompts, completion_pairs=completion_pairs) + self.assertEqual(len(ranks), 2) + self.assertTrue(all(isinstance(rank, int) for rank in ranks)) + + @unittest.skip("This test needs to be run manually since it requires a valid Hugging Face API key.") + def test_hugging_face_judge(self): + judge = HuggingFaceJudge() + prompts, completion_pairs = self._get_prompts_and_completion_pairs() + ranks = judge.judge(prompts=prompts, completion_pairs=completion_pairs) + self.assertEqual(len(ranks), 2) + self.assertTrue(all(isinstance(rank, int) for rank in ranks)) + self.assertEqual(ranks, [0, 1]) diff --git a/trl/__init__.py b/trl/__init__.py index 5c3d3e85ea9..f3b281fcf9d 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -24,6 +24,8 @@ "is_pil_available", "is_wandb_available", "is_xpu_available", + "is_llmblender_available", + "is_openai_available", ], "models": [ "AutoModelForCausalLMWithValueHead", @@ -55,6 +57,14 @@ "SFTTrainer", "FDivergenceConstants", "FDivergenceType", + "WinRateCallback", + "BaseJudge", + "BaseAPIJudge", + "HuggingFaceJudge", + "MockAPIJudge", + "MockJudge", + "OpenAIJudge", + "PairRMJudge", ], "commands": [], "commands.cli_utils": ["init_zero_verbose", "SFTScriptArguments", "DPOScriptArguments", "TrlParser"], @@ -95,6 +105,8 @@ is_pil_available, is_wandb_available, is_xpu_available, + is_llmblender_available, + is_openai_available, ) from .models import ( AutoModelForCausalLMWithValueHead, @@ -126,6 +138,14 @@ SFTTrainer, FDivergenceConstants, FDivergenceType, + WinRateCallback, + BaseJudge, + BaseAPIJudge, + HuggingFaceJudge, + MockAPIJudge, + MockJudge, + OpenAIJudge, + PairRMJudge, ) from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config diff --git a/trl/import_utils.py b/trl/import_utils.py index df44a38aa81..b0810d12efb 100644 --- a/trl/import_utils.py +++ b/trl/import_utils.py @@ -101,6 +101,14 @@ def is_sklearn_available() -> bool: return find_spec("sklearn") is not None +def is_llmblender_available() -> bool: + return find_spec("llm_blender") is not None + + +def is_openai_available() -> bool: + return find_spec("openai") is not None + + def is_xpu_available() -> bool: if is_accelerate_greater_20_0(): import accelerate diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index 16617f5e276..365cd658b7e 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -49,6 +49,16 @@ "sft_trainer": ["SFTTrainer"], "base": ["BaseTrainer"], "ddpo_config": ["DDPOConfig"], + "callbacks": ["RichProgressCallback", "SyncRefModelCallback", "WinRateCallback"], + "judges": [ + "BaseJudge", + "BaseAPIJudge", + "HuggingFaceJudge", + "MockAPIJudge", + "MockJudge", + "OpenAIJudge", + "PairRMJudge", + ], } try: @@ -95,6 +105,8 @@ from .reward_trainer import RewardTrainer, compute_accuracy from .sft_config import SFTConfig from .sft_trainer import SFTTrainer + from .callbacks import RichProgressCallback, SyncRefModelCallback, WinRateCallback + from .judges import BaseJudge, BaseAPIJudge, HuggingFaceJudge, MockAPIJudge, MockJudge, OpenAIJudge, PairRMJudge try: if not is_diffusers_available(): diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index 9052008f067..fbf632a48ee 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -11,20 +11,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import List, Optional, Union import torch from accelerate import Accelerator from accelerate.state import AcceleratorState -from accelerate.utils import is_deepspeed_available +from accelerate.utils import gather_object, is_deepspeed_available from rich.console import Console, Group from rich.live import Live from rich.panel import Panel from rich.progress import Progress -from transformers import PreTrainedModel -from transformers.trainer import TrainerCallback +from transformers import ( + GenerationConfig, + PreTrainedModel, + Trainer, + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) from transformers.trainer_utils import has_length +from ..models.utils import unwrap_model_for_generation +from .judges import BaseJudge + if is_deepspeed_available(): import deepspeed @@ -138,3 +148,88 @@ def on_train_end(self, args, state, control, **kwargs): self.rich_console = None self.training_status = None self.current_step = None + + +class WinRateCallback(TrainerCallback): + """ + A [`~transformers.TrainerCallback`] that computes the win rate of a model based on a reference. + + Usage: + ```python + trainer = DPOTrainer(...) + win_rate_callback = WinRateCallback(..., trainer=trainer) + trainer.add_callback(win_rate_callback) + ``` + + Args: + prompts (`List[str]`): + The prompts to generate completions for. + judge (`BaseJudge`): + The judge to use for comparing completions. + trainer (`Trainer`): + The trainer. + generation_config (`GenerationConfig`, *optional*): + The generation config to use for generating completions. + batch_size (`int`, *optional*): + The batch size to use for generating completions. Defaults to 4. + """ + + def __init__( + self, + prompts: List[str], + judge: BaseJudge, + trainer: Trainer, + generation_config: Optional[GenerationConfig] = None, + batch_size: int = 4, + ): + self.prompts = prompts + self.generation_config = generation_config + self.judge = judge + self.ref_completions = [] + self.trainer = trainer + self.eval_dataset = self.trainer.eval_dataset + if not hasattr(trainer, "ref_model"): + raise AttributeError("Trainer must have a `ref_model` attribute.") + self.batch_size = batch_size + + def generate_completions_for_model(self, model, tokenizer, prompts): + completions = [] + with unwrap_model_for_generation(model, self.trainer.accelerator) as unwrapped_model: + unwrapped_model.eval() + for idx in range(0, len(prompts), self.batch_size): + batch = prompts[idx : idx + self.batch_size] + tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(model.device) + generations = unwrapped_model.generate( + **tokenized_batch, + generation_config=self.generation_config, + ) + for prompt, generation in zip(tokenized_batch.input_ids, generations): + # Remove prompt from generation + generation = generation[len(prompt) :] + completion = tokenizer.decode(generation, skip_special_tokens=True) + completions.append(completion) + + unwrapped_model.train() + return completions + + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + tokenizer = kwargs["tokenizer"] + tokenizer.padding_side = "left" + accelerator = self.trainer.accelerator + with accelerator.split_between_processes(self.eval_dataset["prompt"], apply_padding=True) as prompts: + self.ref_completions = self.generate_completions_for_model(self.trainer.ref_model, tokenizer, prompts) + + def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + model = kwargs["model"] + tokenizer = kwargs["tokenizer"] + accelerator = self.trainer.accelerator + with accelerator.split_between_processes(self.eval_dataset["prompt"], apply_padding=True) as prompts: + completions = self.generate_completions_for_model(model, tokenizer, prompts) + completion_pairs = list(zip(self.ref_completions, completions)) + winner_indices = self.judge.judge(self.eval_dataset["prompt"], completion_pairs) + winner_indices = gather_object(winner_indices) + + # Logging + if self.trainer.accelerator.is_main_process: + win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices) + self.trainer.log({"eval_win_rate": win_rate}) diff --git a/trl/trainer/judges.py b/trl/trainer/judges.py new file mode 100644 index 00000000000..18a6b1ae2c2 --- /dev/null +++ b/trl/trainer/judges.py @@ -0,0 +1,284 @@ +import logging +import os +import random +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from typing import List, Optional + +import numpy as np +from accelerate import Accelerator +from huggingface_hub import InferenceClient +from requests import HTTPError + +from ..import_utils import is_llmblender_available, is_openai_available + + +if is_llmblender_available(): + import llm_blender + +if is_openai_available(): + from openai import BadRequestError, OpenAI + + +DEFAULT_SYSTEM_PROMPT = '''I require a leaderboard for various large language models. I'll provide you with prompts given to these models and their corresponding outputs. Your task is to assess these responses, and select the model that produces the best output from a human perspective. + +## Instruction + +{{ + "instruction": """{prompt}""", +}} + +## Model Outputs + +Here are the unordered outputs from the models. Each output is associated with a specific model, identified by a unique model identifier. + +{{ + {{ + "model_identifier": "0", + "output": """{response1}""" + }}, + {{ + "model_identifier": "1", + "output": """{response2}""" + }} +}} + +## Task + +Evaluate the models based on the quality and relevance of their outputs, and select the model that generated the best output. Answer by providing the model identifier of the best model. We will use your output as the name of the best model, so make sure your output only contains one of the following model identifiers and nothing else (no quotes, no spaces, no new lines, ...): 0 or 1. + +## Best Model Identifier''' + + +class BaseJudge(ABC): + """ + Base class for LLM judges. + + Example: + ```python + class MockJudge(BaseJudge): + def judge(self, prompts, completion_pairs, shuffle_order=True): + return [random.choice([0, 1]) for _ in range(len(prompts))] + + judge = MockJudge() + judge.judge( + prompts=["What is the capital of France?", "What is the capital of Germany?"], + completion_pairs=[["Paris", "Marseille"], ["Munich", "Berlin"]] + ) # [0, 0] + ``` + """ + + @abstractmethod + def judge(self, prompts: List[str], completion_pairs: List[List[str]], shuffle_order: bool = True) -> List[int]: + """ + Judge the completion pairs for the given prompts. + + Args: + prompts (`List[str]`): List of prompts. + completion_pairs (`List[List[str]]`): List of completion pairs, where each pair is a list of two strings. + shuffle_order (`bool`): Whether to shuffle the order of the completion pairs, to avoid positional bias. + + Returns: + List of integers, where each integer is the index of the completion pair that is preferred. + """ + raise NotImplementedError("Judge subclasses must implement this method.") + + +class BaseAPIJudge(BaseJudge): + """ + Base class for LLM judges reached via an API. + + The subclasses of this class should implement the `get_response` method to interact with the API. + + Args: + system_prompt (`str`, *optional*): The system prompt to be used for the judge. If not provided, a default prompt is used. + max_tries (`int`, *optional*): The maximum number of retries for a request. Defaults to 5. + max_workers (`int`, *optional*): The maximum number of parallel requests. Defaults to 8. + + Example: + ```python + class MockAPIJudge(BaseAPIJudge): + def get_response(self, content): + return random.choice(["0", "1"]) + + judge = MockAPIJudge() + judge.judge( + prompts=["What is the capital of France?", "What is the capital of Germany?"], + completion_pairs=[["Paris", "Marseille"], ["Munich", "Berlin"]] + ) # [1, 1] + ``` + """ + + # TODO: add max_requests parameter to limit the number of requests made + def __init__(self, system_prompt: Optional[str] = None, max_tries: int = 5, max_workers: int = 8): + if system_prompt is None: + system_prompt = DEFAULT_SYSTEM_PROMPT + self.system_prompt = system_prompt + self.max_tries = max_tries + self.thread_pool_executor = ThreadPoolExecutor(max_workers=max_workers) + + def __del__(self) -> None: + self.thread_pool_executor.shutdown() + + @abstractmethod + def get_response(self, content: str) -> str: + """ + Get the response from the API for the given content. + + Args: + content (`str`): The string content. + + Returns: + The response from the API as a string. + """ + + raise NotImplementedError("Judge subclasses must implement this method.") + + def judge_single(self, prompt: str, completion_pair: List[str], shuffle_order: bool = True) -> int: + flipped = random.choice([True, False]) if shuffle_order else False + completion_pair = completion_pair[::-1] if flipped else completion_pair + + retry = 0 + while retry < self.max_tries: + content = self.system_prompt.format( + prompt=prompt, response1=completion_pair[0], response2=completion_pair[1] + ) + reply = self.get_response(content) + reply = reply.strip() + + if reply in ["0"]: + return 0 if not flipped else 1 + elif reply in ["1"]: + return 1 if not flipped else 0 + else: + logging.info(f"Judge gave response `{reply}` instead of the expected 0 or 1. Retrying.") + retry += 1 + + logging.info( + f"Max retries reached for prompt:\n\n{prompt}\nand completion pair:\n\n{completion_pair}\n\nReturning random choice." + ) + return random.choice([0, 1]) + + def judge(self, prompts: List[str], completion_pairs: List[List[str]], shuffle_order: bool = True) -> List[int]: + futures = [] + for prompt, completion_pair in zip(prompts, completion_pairs): + future = self.thread_pool_executor.submit(self.judge_single, prompt, completion_pair, shuffle_order) + futures.append(future) + + return [f.result() for f in futures] + + +class PairRMJudge(BaseJudge): + """ + LLM judge based on the PairRM model from AllenAI. + + See: https://huggingface.co/llm-blender/PairRM + """ + + def __init__(self): + if not is_llmblender_available(): + raise ValueError("llm-blender is not installed. Please install it with 'pip install llm-blender'.") + self.blender = llm_blender.Blender() + self.blender.loadranker("llm-blender/PairRM", device=Accelerator().device) + + def judge(self, prompts: List[str], completion_pairs: List[List[str]], shuffle_order: bool = True) -> List[int]: + if shuffle_order: + flip_mask = np.random.choice([True, False], size=len(prompts)) + completion_pairs = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completion_pairs)] + ranks = self.blender.rank(prompts, completion_pairs) + ranks -= 1 # PairRM is 1-indexed, so we subtract 1 to make it 0-indexed + if shuffle_order: + # Flip back the ranks to the original order + ranks[flip_mask] = ranks[flip_mask][:, ::-1] + return ranks[:, 0].tolist() + + +class MockJudge(BaseJudge): + """ + Mock judge that randomly selects a model for each completion pair. + """ + + def judge(self, prompts: List[str], completion_pairs: List[List[str]]) -> List[int]: + return [random.choice([0, 1]) for _ in range(len(prompts))] + + +class MockAPIJudge(BaseAPIJudge): + """ + Mock judge that returns a random choice instead of interacting with an API. + """ + + def get_response(self, content: str) -> str: + return random.choice(["0", "1"]) + + +class HuggingFaceJudge(BaseAPIJudge): + """ + Judge based on the Hugging Face API. + + Args: + model (`str`, *optional*): The model to use for the judge. Defaults to "meta-llama/Meta-Llama-3-70B-Instruct". + system_prompt (`str`, *optional*): The system prompt to be used for the judge. If not provided, a default prompt is used. + max_tries (`int`, *optional*): The maximum number of retries for a request. Defaults to 5. + max_workers (`int`, *optional*): The maximum number of parallel requests. Defaults to 8. + token (`str`, *optional*): The Hugging Face API token to use for the InferenceClient. + """ + + def __init__( + self, + model="meta-llama/Meta-Llama-3-70B-Instruct", + system_prompt: Optional[str] = None, + max_tries: int = 5, + max_workers: int = 8, + token: Optional[str] = None, + ): + super().__init__(system_prompt=system_prompt, max_tries=max_tries, max_workers=max_workers) + self.client = InferenceClient(model=model, token=token) + + def get_response(self, content: str) -> str: + try: + response = self.client.chat_completion( + messages=[{"role": "user", "content": content}], + max_tokens=1, + stop=["<|eot_id|>"], # For llama-3 models + ) + return response.choices[0].message.content + except HTTPError as e: + logging.info(f"Unable to reach the Hugging Face API due to error: {e}\nReturning random choice (0,1)") + return random.choice(["0", "1"]) + + +class OpenAIJudge(BaseAPIJudge): + """ + Judge based on the OpenAI API. + + Args: + model (`str`, *optional*): The model to use for the judge. Defaults to "gpt-4-turbo-preview". + system_prompt (`str`, *optional*): The system prompt to be used for the judge. If not provided, a default prompt is used. + max_tries (`int`, *optional*): The maximum number of retries for a request. Defaults to 5. + max_workers (`int`, *optional*): The maximum number of parallel requests. Defaults to 8. + """ + + def __init__( + self, + model="gpt-4-turbo-preview", + system_prompt: Optional[str] = None, + max_tries: int = 5, + max_workers: int = 8, + ): + if not is_openai_available(): + raise ValueError("OpenAI client is not installed. Please install it with 'pip install openai'.") + super().__init__(system_prompt=system_prompt, max_tries=max_tries, max_workers=max_workers) + self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"]) + self.model = model + + def get_response(self, content: str) -> str: + try: + response = self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": content}], + max_tokens=1, # TODO: let users configure these variables + ) + return response.choices[0].message.content + except BadRequestError as e: + logging.warn(f"Unable to reach to OpenAI API due to error: {e}\nReturning random choice (0, 1)") + return random.choice(["0", "1"])