Add optional r1-style thinking reward #551

Merged · 10 commits · Feb 6, 2025
21 changes: 21 additions & 0 deletions docs/algorithms/grpo.md
@@ -0,0 +1,21 @@
# Grouped Relative Policy Optimization (GRPO)

GRPO is an online RL method used in the [DeepSeek R1 paper](https://arxiv.org/abs/2501.12948); it first appeared in [DeepSeekMath](https://arxiv.org/abs/2402.03300).

`open_instruct/grpo_vllm_thread_ray_gtrl.py` contains an implementation of GRPO.
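
For orientation, here is a minimal, self-contained sketch (toy numbers, plain PyTorch; not the training code in this repo) of the group-relative advantage GRPO optimizes: several responses are sampled per prompt, and each response's reward is standardized against the mean and standard deviation of its own group, so no learned value function is needed.

```python
import torch

# Rewards for the sampled responses of two prompts (group size 4), toy numbers.
rewards = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],   # prompt 1: two correct, two incorrect
    [0.0, 0.0, 1.0, 0.0],   # prompt 2: one correct
])

# Group-relative advantage: standardize each reward within its own prompt's group.
mean = rewards.mean(dim=1, keepdim=True)
std = rewards.std(dim=1, keepdim=True)
advantages = (rewards - mean) / (std + 1e-8)
print(advantages)
```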


## Get started


Here is a command to run GRPO with Llama-3.1-8B on [ai2-adapt-dev/rlvr_gsm8k_zs](https://huggingface.co/datasets/ai2-adapt-dev/rlvr_gsm8k_zs), which is simply a zero-shot version of the RLVR GSM8K dataset.


```bash
bash scripts/train/rlvr/grpo_llama3.1-8b.sh
```

The results look quite reasonable: the format score and the overall score both go up, the KL does not explode, and the sequence length appears stable (at least initially).


![GRPO Llama-3.1-8B training curves](grpo_8b.png)
18 changes: 18 additions & 0 deletions open_instruct/dataset_transformation.py
@@ -182,6 +182,24 @@ def visualize_token(tokens: list[int], tokenizer: PreTrainedTokenizer):
"{% endif %}"
"{% endfor %}"
),
"r1_simple_chat_postpend_think": (
"A conversation between User and Assistant. "
"The user asks a question, and the Assistant solves it. "
"The assistant first thinks about the reasoning process in "
"the mind and then provides the user with the answer. "
"The reasoning process and answer are enclosed within <think> </think> "
"and <answer> </answer> tags, respectively, "
"i.e., <think> reasoning process here </think> "
"<answer> answer here </answer>."
"\n\n"
"{% for message in messages %}"
"{{ '\n\n' if not loop.first else '' }}"
"{{ message['role'].capitalize() + ': ' + message['content'] + '\n' }}"
"{% if loop.last and add_generation_prompt %}"
"{{ 'Assistant: <think>' }}"
"{% endif %}"
"{% endfor %}"
),
}
# flake8: noqa

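To make the new template concrete, here is a small standalone sketch (assumptions: `jinja2` is installed; the template text is copied from the diff above) that renders `r1_simple_chat_postpend_think` for a one-turn conversation. In open_instruct the template is applied through the tokenizer's chat-template machinery rather than raw `jinja2`; this only illustrates that the rendered prompt ends with `Assistant: <think>`, so the model continues directly with its reasoning.

```python
from jinja2 import Template

# Template text copied from the "r1_simple_chat_postpend_think" entry above.
R1_TEMPLATE = (
    "A conversation between User and Assistant. "
    "The user asks a question, and the Assistant solves it. "
    "The assistant first thinks about the reasoning process in "
    "the mind and then provides the user with the answer. "
    "The reasoning process and answer are enclosed within <think> </think> "
    "and <answer> </answer> tags, respectively, "
    "i.e., <think> reasoning process here </think> "
    "<answer> answer here </answer>."
    "\n\n"
    "{% for message in messages %}"
    "{{ '\n\n' if not loop.first else '' }}"
    "{{ message['role'].capitalize() + ': ' + message['content'] + '\n' }}"
    "{% if loop.last and add_generation_prompt %}"
    "{{ 'Assistant: <think>' }}"
    "{% endif %}"
    "{% endfor %}"
)

messages = [{"role": "user", "content": "What is 7 * 6?"}]
prompt = Template(R1_TEMPLATE).render(messages=messages, add_generation_prompt=True)
print(prompt)
# The rendered prompt ends with:
#   User: What is 7 * 6?
#   Assistant: <think>
```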
29 changes: 23 additions & 6 deletions open_instruct/ground_truth_utils.py
@@ -1,12 +1,21 @@
'''
"""
Collection of 'ground truth rewards' for different datasets/tasks.
Used to give feedback to the model based on the ground truth answer.
'''
import re
"""

import json
import re
import string
from open_instruct.math_utils import last_boxed_only_string, remove_boxed, get_unnormalized_answer, normalize_final_answer, is_equiv, hendrycks_is_equiv

from open_instruct.if_functions import IF_FUNCTIONS_MAP
from open_instruct.math_utils import (
get_unnormalized_answer,
hendrycks_is_equiv,
is_equiv,
last_boxed_only_string,
normalize_final_answer,
remove_boxed,
)


def verify_gsm8k_sample(model_output, ground_truth_answer):
@@ -138,11 +147,19 @@ def verify_flan_sample(model_output, ground_truth_answer):
return normalize_answer(answer_string) == normalize_answer(ground_truth_answer)


def soft_format_reward_func(responses: list[str], reward_scale: float = 1.0) -> list[float]:
"""Reward function that checks if the completion has a specific format."""
pattern = r".*?</think>\s*<answer>.*?</answer>"
matches = [re.match(pattern, r, re.DOTALL) for r in responses]
return [reward_scale if match else 0.0 for match in matches]


# debug code
if __name__ == "__main__":
Collaborator comment:

@vwxyzjn let's add this to the tests/checks we have?
Also, let's make the scale of the reward settable by a hyperparam / config? Could get tricky reward shaping issues.

from datasets import load_dataset

ds = load_dataset("ai2-adapt-dev/prompts_with_constraints_for_ground_truth")
test_model_output = "<|assistant|>\nThe answer is $\\boxed{3.14}$"
for sample in ds['train']:
for sample in ds["train"]:
print(sample)
verify_ifeval_sample(test_model_output, sample['ground_truth'])
verify_ifeval_sample(test_model_output, sample["ground_truth"])
24 changes: 23 additions & 1 deletion open_instruct/grpo_vllm_thread_ray_gtrl.py
@@ -47,6 +47,7 @@
TokenizerConfig,
get_cached_dataset_rlvr,
)
from open_instruct.ground_truth_utils import soft_format_reward_func

os.environ["NCCL_CUMEM_ENABLE"] = "0" # NOQA

@@ -234,6 +235,10 @@ class Args:
"""the reward model multiplier, for down/upscaling the reward model output"""
verification_reward: float = 10.0
"""the reward value for verifiable responses"""
add_r1_style_format_reward: bool = False
"""whether to add the R1 style format reward"""
r1_style_format_reward: float = 1.0
"""the reward value for R1 style format reward"""

# async setting
async_mode: bool = True
@@ -1058,7 +1063,9 @@ def vllm_generate(
sequence_lengths = []
if accelerator.is_main_process:
g_response_token_ids = response_ids_Q.get()
DUMMY_PAD_TOKEN = 0 # we can't use tokenizer.pad_token_id because it's outside vocab and `torch.gather(all_logprob, 2, response.unsqueeze(-1))` will error out
DUMMY_PAD_TOKEN = (
args.stop_token_id
) # we can't use tokenizer.pad_token_id because it's outside vocab and `torch.gather(all_logprob, 2, response.unsqueeze(-1))` will error out
g_padded_response_ids = [
response + [DUMMY_PAD_TOKEN] * (args.response_length - len(response))
for response in g_response_token_ids
@@ -1071,6 +1078,12 @@
]
# print(f"{local_vllm_responses.shape=}, {local_vllm_responses=}")
query_responses = torch.cat((queries, local_vllm_responses), 1)

if args.add_r1_style_format_reward:
decoded_response = tokenizer.batch_decode(local_vllm_responses)
format_scores = torch.tensor(
soft_format_reward_func(decoded_response, args.r1_style_format_reward), device=device
)
for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
# print(f"get reward stuff starts {i=}")
query = queries[i : i + args.local_rollout_forward_batch_size]
@@ -1122,6 +1135,9 @@
else:
verifiable_count = torch.tensor([0.0], device=device).float()

if args.add_r1_style_format_reward:
score += format_scores[i : i + args.local_rollout_forward_batch_size]

responses.append(response)
postprocessed_responses.append(postprocessed_response)
logprobs.append(logprob)
@@ -1140,6 +1156,10 @@
verifiable_counts = torch.cat(verifiable_counts, 0)
verifiable_correct_rate = verifiable_counts.sum() / queries.shape[0]
# print(f"get reward stuff finished")
if self.rank == 0:
print(f"{sequence_lengths=}")
print(f"{postprocessed_responses[0]=}")
print(f"{tokenizer.decode(postprocessed_responses[0])=}")
del (logprob, ref_logprob, score)
gc.collect()
torch.cuda.empty_cache()
@@ -1294,6 +1314,8 @@
local_metrics.add("val/ratio", ratio_stats.mean())
local_metrics.add("val/ratio_var", ratio_stats.var())
local_metrics.add("val/stop_token_rate", contain_stop_token.float().mean())
if args.add_r1_style_format_reward:
local_metrics.add("val/format_scores", format_scores.float().mean())

metrics = {
"episode": episode,
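The wiring above can be summarized with a simplified, standalone sketch (toy tensors, no vLLM/Ray; the batch size and reward values are illustrative): when `--add_r1_style_format_reward` is set, format scores are computed once for the whole local batch of decoded responses and then sliced to match each rollout-forward mini-batch before being added to the verifiable reward.

```python
import torch

from open_instruct.ground_truth_utils import soft_format_reward_func

decoded_responses = [
    "some reasoning </think> <answer> 4 </answer>",
    "no tags at all",
    "more reasoning </think><answer> 9 </answer>",
    "also no tags",
]
r1_style_format_reward = 1.0
verification_reward = 10.0

format_scores = torch.tensor(soft_format_reward_func(decoded_responses, r1_style_format_reward))
# Pretend the verifier marked responses 0, 2, and 3 as correct.
verifiable_scores = verification_reward * torch.tensor([1.0, 0.0, 1.0, 1.0])

local_rollout_forward_batch_size = 2
for i in range(0, len(decoded_responses), local_rollout_forward_batch_size):
    score = verifiable_scores[i : i + local_rollout_forward_batch_size].clone()
    score += format_scores[i : i + local_rollout_forward_batch_size]
    print(score)  # tensor([11.,  0.]), then tensor([11., 10.])
```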
40 changes: 40 additions & 0 deletions scripts/train/rlvr/grpo_llama3.1-8b.sh
@@ -0,0 +1,40 @@
python open_instruct/grpo_vllm_thread_ray_gtrl.py \
--exp_name $exp_name \
--output_dir /weka/oe-adapt-default/costah/models/$exp_name \
--dataset_mixer_list ai2-adapt-dev/rlvr_gsm8k_zs 1.0 \
--dataset_mixer_list_splits train \
--dataset_mixer_eval_list ai2-adapt-dev/rlvr_gsm8k_zs 1.0 \
--dataset_mixer_eval_list_splits train \
--max_token_length 2048 \
--max_prompt_token_length 2048 \
--response_length 2048 \
--number_samples_per_prompt 4 \
--model_name_or_path meta-llama/Llama-3.1-8B \
--stop_strings '"</answer>"' \
--add_r1_style_format_reward \
--non_stop_penalty False \
--stop_token eos \
--penalty_reward_value 0.0 \
--temperature 0.7 \
--ground_truths_key ground_truth \
--chat_template_name r1_simple_chat_postpend_think \
--sft_messages_key messages \
--learning_rate 3e-7 \
--total_episodes 1000000 \
--deepspeed_stage 3 \
--per_device_train_batch_size 1 \
--local_rollout_forward_batch_size 1 \
--local_mini_batch_size 16 \
--local_rollout_batch_size 16 \
--num_epochs 1 \
--actor_num_gpus_per_node 6 \
--vllm_tensor_parallel_size 2 \
--beta 0.01 \
--apply_verifiable_reward true \
--seed 3 \
--num_evals 100 \
--save_freq 100 \
--reward_model_multiplier 0.0 \
--no_try_launch_beaker_eval_jobs \
--gradient_checkpointing \
--with_tracking