-
Notifications
You must be signed in to change notification settings - Fork 2.8k
[Doc] Drop dummy reward and dataset for DeepMath-103K and accuracy reward #4524
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
9e3d1f5
Drop dummy rewards and dataset for more infortive examples
qgallouedec 4d37f4e
style
qgallouedec f9224b9
Merge branch 'main' into real-example-for-grpo
qgallouedec 04f2b4d
Update dataset references in GRPO and RLOO trainer documentation to D…
qgallouedec 6872ed4
Merge branch 'main' into real-example-for-grpo
qgallouedec File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,10 +14,10 @@ This post-training method was contributed by [Quentin Gallouédec](https://huggi | |
|
|
||
| ## Quick start | ||
|
|
||
| This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [UltraFeedback prompts dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt). You can view the data in the dataset here: | ||
| This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [DeepMath-103K dataset](https://huggingface.co/datasets/trl-lib/DeepMath-103K). You can view the data in the dataset here: | ||
|
|
||
| <iframe | ||
| src="https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt/embed/viewer/default/train?row=0" | ||
| src="https://huggingface.co/datasets/trl-lib/DeepMath-103K/embed/viewer/default/train?row=0" | ||
| frameborder="0" | ||
| width="100%" | ||
| height="560px" | ||
|
|
@@ -28,21 +28,14 @@ Below is the script to train the model. | |
| ```python | ||
| # train_grpo.py | ||
| from datasets import load_dataset | ||
| from trl import GRPOConfig, GRPOTrainer | ||
|
|
||
| dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train") | ||
| from trl import GRPOTrainer | ||
| from trl.rewards import accuracy_reward | ||
|
|
||
| # Dummy reward function for demonstration purposes | ||
| def reward_num_unique_letters(completions, **kwargs): | ||
| """Reward function that rewards completions with more unique letters.""" | ||
| completion_contents = [completion[0]["content"] for completion in completions] | ||
| return [float(len(set(content))) for content in completion_contents] | ||
| dataset = load_dataset("trl-lib/DeepMath-103K", split="train") | ||
|
|
||
| training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO") | ||
| trainer = GRPOTrainer( | ||
| model="Qwen/Qwen2-0.5B-Instruct", | ||
| reward_funcs=reward_num_unique_letters, | ||
| args=training_args, | ||
| reward_funcs=accuracy_reward, | ||
| train_dataset=dataset, | ||
| ) | ||
| trainer.train() | ||
|
|
@@ -290,29 +283,27 @@ import argparse | |
|
|
||
| from datasets import load_dataset | ||
| from trl import GRPOTrainer, GRPOConfig | ||
| from trl.rewards import accuracy_reward | ||
|
|
||
| def main(): | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument("--vllm_server_host", type=str, default="", help="The server IP") | ||
| args = parser.parse_args() | ||
|
|
||
| # Example dataset from TLDR | ||
| dataset = load_dataset("trl-lib/tldr", split="train") | ||
|
|
||
| # Dummy reward function: count the number of unique characters in the completions | ||
| def reward_num_unique_chars(completions, **kwargs): | ||
| return [len(set(c)) for c in completions] | ||
| dataset = load_dataset("trl-lib/DeepMath-103K", split="train") | ||
|
|
||
| training_args = GRPOConfig( | ||
| output_dir="Qwen2.5-72B-GRPO", | ||
| per_device_train_batch_size=4, | ||
| bf16=True, | ||
| gradient_checkpointing=True, | ||
|
Comment on lines
-309
to
-310
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these are default values |
||
| use_vllm=True, | ||
| vllm_server_host=args.vllm_server_host.replace("ip-", "").replace("-", "."), # from ip-X-X-X-X to X.X.X.X | ||
| ) | ||
|
|
||
| trainer = GRPOTrainer(model="Qwen/Qwen2.5-72B", args=training_args, reward_funcs=reward_num_unique_chars, train_dataset=dataset) | ||
| trainer = GRPOTrainer( | ||
| model="Qwen/Qwen2.5-72B", | ||
| args=training_args, | ||
| reward_funcs=accuracy_reward, | ||
| train_dataset=dataset | ||
| ) | ||
| trainer.train() | ||
|
|
||
| if __name__=="__main__": | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
generated automatically, no need to pass it explicitly