diff --git a/README.md b/README.md index 016ec2ff1eb..c0a49208262 100644 --- a/README.md +++ b/README.md @@ -92,16 +92,13 @@ trainer.train() ```python from datasets import load_dataset from trl import GRPOTrainer +from trl.rewards import accuracy_reward -dataset = load_dataset("trl-lib/tldr", split="train") - -# Dummy reward function: count the number of unique characters in the completions -def reward_num_unique_chars(completions, **kwargs): - return [len(set(c)) for c in completions] +dataset = load_dataset("trl-lib/DeepMath-103K", split="train") trainer = GRPOTrainer( model="Qwen/Qwen2-0.5B-Instruct", - reward_funcs=reward_num_unique_chars, + reward_funcs=accuracy_reward, train_dataset=dataset, ) trainer.train() diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index bdc132e4115..92a40b009d2 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -14,10 +14,10 @@ This post-training method was contributed by [Quentin Gallouédec](https://huggi ## Quick start -This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [UltraFeedback prompts dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt). You can view the data in the dataset here: +This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [DeepMath-103K dataset](https://huggingface.co/datasets/trl-lib/DeepMath-103K). You can view the data in the dataset here: