GRPO implementation update #534

vwxyzjn opened this issue Jan 29, 2025 · 17 comments

vwxyzjn commented Jan 29, 2025

Let's use this issue to share the latest GRPO development updates. CC @gauravpandeyamu, thanks for your fix.

The command below (without @gauravpandeyamu's fix) yields the charts below. Overall, the training score and sequence length go up, but the downstream MATH eval seems to suffer. I am going to try out @gauravpandeyamu's fix for KL regularization to see if it helps.

for beta in 0.03 0.01 0.0; do
for nspp in 16 32; do
for m in half-m ; do
local_rollout_batch_size=8
if [ $m == "half-m" ]; then
    local_mini_batch_size=$(($local_rollout_batch_size * $nspp / 2))
else
    local_mini_batch_size=$(($local_rollout_batch_size * $nspp))
fi
exp_name="0128_grpo_math_zs_${beta}_${nspp}_${m}_${RANDOM}"
full_bsz=$(($local_rollout_batch_size * nspp * (8 + 8 + 8 + 7) * 2))
echo $exp_name:
echo --- local_mini_batch_size=$local_mini_batch_size
echo --- full_bsz=$full_bsz
echo --- num_gradient_updates=$(($local_rollout_batch_size * $nspp / $local_mini_batch_size))
exp_name="0128_grpo_math_zs_${beta}_${nspp}_${RANDOM}"
python mason.py \
    --cluster ai2/jupiter-cirrascale-2 \
    --workspace ai2/tulu-3-dev \
    --priority high \
    --preemptible \
    --num_nodes 1 \
    --max_retries 1 \
    --budget ai2/oe-adapt \
    --gpus 8 -- source configs/beaker_configs/ray_node_setup.sh \&\& uv run python open_instruct/grpo_vllm_thread_ray_gtrl.py \
    --exp_name $exp_name \
    --beta $beta \
    --local_mini_batch_size $local_mini_batch_size \
    --number_samples_per_prompt $nspp \
    --output_dir /weka/oe-adapt-default/costah/models/$exp_name \
    --local_rollout_batch_size $local_rollout_batch_size \
    --dataset_mixer "{\"ai2-adapt-dev/math_ground_truth_zs\": 1.0}" \
    --dataset_train_splits train \
    --dataset_eval_mixer "{\"ai2-adapt-dev/math_ground_truth_zs\": 32}" \
    --dataset_eval_splits train \
    --max_token_length 2048 \
    --max_prompt_token_length 2048 \
    --response_length 2048 \
    --model_name_or_path allenai/Llama-3.1-Tulu-3-8B-DPO \
    --non_stop_penalty \
    --stop_token eos \
    --temperature 1.0 \
    --ground_truths_key ground_truth \
    --chat_template tulu \
    --sft_messages_key messages \
    --learning_rate 5e-7 \
    --total_episodes 1000000 \
    --penalty_reward_value 0.0 \
    --deepspeed_stage 2 \
    --per_device_train_batch_size 2 \
    --local_rollout_forward_batch_size 2 \
    --actor_num_gpus_per_node 6 \
    --num_epochs 1 \
    --vllm_tensor_parallel_size 2 \
    --lr_scheduler_type constant \
    --apply_verifiable_reward true \
    --seed 1 \
    --num_evals 1000 \
    --save_freq 40 \
    --reward_model_multiplier 0.0 \
    --no_try_launch_beaker_eval_jobs \
    --try_launch_beaker_eval_jobs_on_weka \
    --gradient_checkpointing \
    --with_tracking
done
done
done
[charts omitted]
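For reference, the mini-batch arithmetic above works out to: with local_rollout_batch_size=8 and nspp=16 under half-m, local_mini_batch_size = 8 * 16 / 2 = 64 and num_gradient_updates = 8 * 16 / 64 = 2 per rollout batch; without half-m the mini-batch equals the full rollout batch, i.e. a single gradient update.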

vwxyzjn commented Jan 30, 2025

The KL constraint didn't work well: 0129_grpo_math_kl_fix_zs_0.03_16_half-m_30414__1__1738185952, which includes the KL fix, seems to have even higher KL than the previously incorrect run.

[chart omitted]


gauravpandeyamu commented Jan 30, 2025

That's a good catch. While the kl1 estimator is a "decent" estimator of the KL divergence, its gradient is not the "correct" estimator of the gradient of the KL divergence. I suspect that you are using the kl1 estimator.

The gradient of the kl divergence between the policy $\pi_\theta$ and $\pi_{ref}$, the reference policy, is given by
$$\nabla_\theta \sum_{y} \pi_\theta (y|x) \log \frac{\pi_\theta(y|x)}{\pi_{ref}(y|x)}$$

If the samples $y_1, ..., y_n$ come from $\pi_t(y|x)$, the gradient can be estimated as
$$\nabla_\theta \frac{1}{n} \sum_{i=1}^n \frac{\pi_\theta (y_i|x)}{\pi_t(y_i|x)} \log \frac{\pi_\theta(y_i|x)}{\pi_{ref}(y_i|x)}$$

This is what kl1 should have been.
kl2 and kl3 are still fine.

BTW, the GRPO paper recommends using the kl3 estimator (Equation 4 in https://arxiv.org/pdf/2402.03300).
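For reference, written out in this notation (with samples $y \sim \pi_\theta(\cdot|x)$), the per-sample estimators under discussion are

$$k_1 = \log \frac{\pi_\theta(y|x)}{\pi_{ref}(y|x)}, \qquad k_2 = \frac{1}{2}\left(\log \frac{\pi_\theta(y|x)}{\pi_{ref}(y|x)}\right)^2, \qquad k_3 = \frac{\pi_{ref}(y|x)}{\pi_\theta(y|x)} - \log \frac{\pi_{ref}(y|x)}{\pi_\theta(y|x)} - 1,$$

matching the kl1/kl2/kl3 definitions in the code snippets below.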


vwxyzjn commented Jan 30, 2025

Yeah, so I thought about the kl3. The issue is that kl3 blows up: see

huggingface/trl#423 (comment)

but this only happens when using KL as part of the reward. I am not sure what happens when we put it in the loss directly...
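For context, the two placements look roughly like this; a rough sketch with illustrative names, not the actual open_instruct code:

```python
import torch

def kl_in_reward(rewards: torch.Tensor, new_logprobs: torch.Tensor,
                 ref_logprobs: torch.Tensor, beta: float) -> torch.Tensor:
    # Classic RLHF-style placement: fold the per-token KL penalty into the
    # reward before advantages are computed.
    kl = new_logprobs - ref_logprobs  # per-token log(pi_theta / pi_ref)
    return rewards - beta * kl

def kl_in_loss(pg_loss: torch.Tensor, new_logprobs: torch.Tensor,
               ref_logprobs: torch.Tensor, mask: torch.Tensor, beta: float) -> torch.Tensor:
    # GRPO-style placement: add the KL term directly to the policy loss.
    kl = new_logprobs - ref_logprobs
    kl_loss = (kl * mask).sum() / mask.sum().clamp(min=1)
    return pg_loss + beta * kl_loss
```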

@gauravpandeyamu

Interesting. I agree that it might work if added to the loss directly.

I have modified my local fork with the following changes:

                    # kl loss should be computed without torch.no_grad()
                    ref_logprobs_diff = new_logprobs - ref_logprobs[micro_batch_inds]
                    kl1 = ratio * ref_logprobs_diff
                    kl2 = (ref_logprobs_diff) ** 2 / 2
                    kl3 = (-ref_logprobs_diff).exp() - 1 + ref_logprobs_diff
                    if args.kl_estimator == "kl1":
                        kl = kl1
                    elif args.kl_estimator == "kl2":
                        kl = kl2
                    elif args.kl_estimator == "kl3":
                        kl = kl3

                    kl_loss = masked_mean(kl, ~padding_mask[micro_batch_inds])
                    pg_loss = pg_loss + args.beta * kl_loss

Basically, it gives a kl1 estimator whose gradient is unbiased. I haven't created a new PR since it has conflicting changes with your PR.


vwxyzjn commented Jan 30, 2025

I can help test it out. I modified it to the following and launched a scan. Will report back with results.

                        kl1 = new_logprobs - mb_reflogprobs
                        kl2 = (kl1) ** 2 / 2
                        kl3 = (-kl1).exp() - 1 + kl1
                        kl4 = ratio * kl1
                        if args.kl_estimator == "kl1":
                            kl = kl1
                        elif args.kl_estimator == "kl2":
                            kl = kl2
                        elif args.kl_estimator == "kl3":
                            kl = kl3
                        elif args.kl_estimator == "kl4":
                            kl = kl4

                        if epoch_idx == 0:
                            kl_stats[micro_batch_inds] = kl.sum(1).float()
                        
                        # grpo change: directly subtract KL in loss (add)
                        pg_loss = masked_mean(pg_loss_max + (args.beta * kl), ~padding_mask[micro_batch_inds])
for beta in 0.03; do
for nspp in 16; do
for m in half-m ; do
for kl_estimator in k1 k2 k3 k4; do
local_rollout_batch_size=8
if [ $m == "half-m" ]; then
    local_mini_batch_size=$(($local_rollout_batch_size * $nspp / 2))
else
    local_mini_batch_size=$(($local_rollout_batch_size * $nspp))
fi
exp_name="0130_kl_scan_grpo_math_zs_${kl_estimator}_${beta}_${nspp}_${m}_${RANDOM}"
full_bsz=$(($local_rollout_batch_size * nspp * (8 + 8 + 8 + 7) * 2))
echo $exp_name:
echo --- local_mini_batch_size=$local_mini_batch_size
echo --- full_bsz=$full_bsz
echo --- num_gradient_updates=$(($local_rollout_batch_size * $nspp / $local_mini_batch_size))
python mason.py \
    --cluster ai2/jupiter-cirrascale-2 \
    --workspace ai2/tulu-3-dev \
    --priority high \
    --preemptible \
    --num_nodes 1 \
    --max_retries 1 \
    --budget ai2/oe-adapt \
    --gpus 8 -- source configs/beaker_configs/ray_node_setup.sh \&\& uv run python open_instruct/grpo_vllm_thread_ray_gtrl.py \
    --exp_name $exp_name \
    --beta $beta \
    --local_mini_batch_size $local_mini_batch_size \
    --number_samples_per_prompt $nspp \
    --output_dir /weka/oe-adapt-default/costah/models/$exp_name \
    --local_rollout_batch_size $local_rollout_batch_size \
    --dataset_mixer "{\"ai2-adapt-dev/math_ground_truth_zs\": 1.0}" \
    --dataset_train_splits train \
    --dataset_eval_mixer "{\"ai2-adapt-dev/math_ground_truth_zs\": 32}" \
    --dataset_eval_splits train \
    --max_token_length 2048 \
    --max_prompt_token_length 2048 \
    --response_length 2048 \
    --model_name_or_path allenai/Llama-3.1-Tulu-3-8B-DPO \
    --non_stop_penalty \
    --stop_token eos \
    --temperature 1.0 \
    --ground_truths_key ground_truth \
    --chat_template tulu \
    --sft_messages_key messages \
    --learning_rate 5e-7 \
    --total_episodes 1000000 \
    --penalty_reward_value 0.0 \
    --deepspeed_stage 2 \
    --per_device_train_batch_size 2 \
    --local_rollout_forward_batch_size 2 \
    --actor_num_gpus_per_node 6 \
    --num_epochs 1 \
    --vllm_tensor_parallel_size 2 \
    --lr_scheduler_type constant \
    --apply_verifiable_reward true \
    --seed 1 \
    --num_evals 1000 \
    --save_freq 40 \
    --reward_model_multiplier 0.0 \
    --no_try_launch_beaker_eval_jobs \
    --try_launch_beaker_eval_jobs_on_weka \
    --gradient_checkpointing \
    --with_tracking
done
done
done
done

@gauravpandeyamu

Great. Thanks


vwxyzjn commented Jan 30, 2025

Eh, this is a bit awkward: none of the KL estimators seems to control KL well.

[chart omitted]

@gauravpandeyamu

Since RLVR looks only at the final answer, completely ignoring the generated text, the model can be motivated to generate text that deviates from the reference policy as long as it reaches the correct final answer.

Perhaps a stronger KL term (a higher beta value) is desirable in RLVR, to ensure that the generated text remains sensible.


vwxyzjn commented Jan 31, 2025

FYI on the KL3 estimator: https://x.com/vwxyzjn/status/1885329398821187633


vwxyzjn commented Jan 31, 2025

Ok, so when I launched the experiments I accidentally left out the --kl_estimator flag 🤡, so all the experiments were run using kl1.

Now, when using kl3, things look much more reasonable.

[chart omitted]


vwxyzjn commented Jan 31, 2025

Meanwhile

1. kl1 looks wrong (larger beta induces larger KL...???)
   [chart omitted]
2. kl2 also looks reasonable
   [chart omitted]
3. kl4 seems ok
   [chart omitted]

@gauravpandeyamu why did you multiply the ref_logprobs_diff by ratio? I don't get it.

@gauravpandeyamu

Ahh, yes. Now, the graphs of kl1, kl2, kl3 and kl4 make perfect sense.

As for why the kl1 estimator is a bad estimator, and how multiplying by the ratio fixes it (mainly, the bias is fixed), here is ChatGPT's response:

https://chatgpt.com/share/679dfb75-cdec-800f-9078-f838d3925f9e
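In short, with $\pi_{old}$ the sampling policy (assumed to be what the ratio is computed against, and treated as a constant w.r.t. $\theta$):

$$\mathbb{E}_{y \sim \pi_{old}}\left[\nabla_\theta\left(\frac{\pi_\theta(y|x)}{\pi_{old}(y|x)} \log \frac{\pi_\theta(y|x)}{\pi_{ref}(y|x)}\right)\right] = \mathbb{E}_{y \sim \pi_\theta}\left[\nabla_\theta \log \pi_\theta(y|x)\left(\log \frac{\pi_\theta(y|x)}{\pi_{ref}(y|x)} + 1\right)\right] = \nabla_\theta\, \mathrm{KL}(\pi_\theta \| \pi_{ref}),$$

using $\mathbb{E}_{\pi_\theta}[\nabla_\theta \log \pi_\theta] = 0$. Without the ratio, the per-sample gradient of kl1 is just $\nabla_\theta \log \pi_\theta(y|x)$, which does not estimate $\nabla_\theta \mathrm{KL}$.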

@tristandeleu

While the kl1 estimator is a "decent" estimator of the KL divergence, its gradient is not the "correct" estimator of the gradient of the KL divergence.

If I am not mistaken, I believe that the gradient of KL3 is also not an unbiased estimator of the gradient of the KL divergence.

$$\mathbb{E}_{\pi_{\theta}}\left[\nabla_{\theta}\left(\frac{\pi_{\mathrm{ref}}}{\pi_{\theta}} - \log \frac{\pi_{\mathrm{ref}}}{\pi_{\theta}} - 1\right)\right] = -\mathbb{E}_{\pi_{\theta}}\left[\nabla_{\theta}\log \pi_{\theta} \times \frac{\pi_{\mathrm{ref}}}{\pi_{\theta}}\right] \neq -\mathbb{E}_{\pi_{\theta}}\left[\nabla_{\theta} \log \pi_{\theta} \times \log \frac{\pi_{\mathrm{ref}}}{\pi_{\theta}}\right] = \nabla_{\theta}\mathrm{KL}(\pi_{\theta}\|\pi_{\mathrm{ref}})$$

Given this, I don't fully understand how KL3 can work for optimization (maybe the lower variance alone is good enough, despite this potential bias). It should be a good estimator to monitor the KL divergence (notably because it remains non-negative, is unbiased and has low variance), but using it for optimization is a bit mysterious to me.

I'm sorry to hijack this issue, but I haven't seen much discussion on the use of KL3 in GRPO.

@gauravpandeyamu

While the kl1 estimator is a "decent" estimator of the KL divergence, its gradient is not the "correct" estimator of the gradient of the KL divergence.

If I am not mistaken, I believe that the gradient of KL3 is also not an unbiased estimator of the gradient of the KL divergence.

$$\mathbb{E}_{\pi_{\theta}}\left[\nabla_{\theta}\left(\frac{\pi_{\mathrm{ref}}}{\pi_{\theta}} - \log \frac{\pi_{\mathrm{ref}}}{\pi_{\theta}} - 1\right)\right] = -\mathbb{E}_{\pi_{\theta}}\left[\nabla_{\theta}\log \pi_{\theta} \times \frac{\pi_{\mathrm{ref}}}{\pi_{\theta}}\right] \neq -\mathbb{E}_{\pi_{\theta}}\left[\nabla_{\theta} \log \pi_{\theta} \times \log \frac{\pi_{\mathrm{ref}}}{\pi_{\theta}}\right] = \nabla_{\theta}\mathrm{KL}(\pi_{\theta}\|\pi_{\mathrm{ref}})$$

Given this, I don't fully understand how KL3 can work for optimization (maybe the lower variance alone is good enough, despite this potential bias). It should be a good estimator to monitor the KL divergence (notably because it remains non-negative, is unbiased and has low variance), but using it for optimization is a bit mysterious to me.

I'm sorry to hijack this issue, but I haven't seen much discussion on the use of KL3 in GRPO.

You are right. It is biased but has low variance. Intuitively, another reason why it works is that each term of the kl3 estimator, $-\log \frac{p(x)}{q(x)} + \frac{p(x)}{q(x)} - 1$ (writing $p$ for $\pi_{ref}$ and $q$ for $\pi_\theta$), is lower bounded by 0 (same as kl2), with equality achieved only when $p(x)=q(x)$. So, if you try to estimate the KL with just one sample and optimize the kl3 estimator for that sample, you will end up making the two probabilities equal for that sample.

On the other hand, if you try to minimize the kl1 estimator for a single sample, $-\log \frac{p(x)}{q(x)}$, you will just end up minimizing $q(x)$ for that sample rather than matching it to $p(x)$. This is why optimizing kl1 leads to a blow-up in KL when beta is higher.
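Concretely, the per-sample bound follows from the convexity of $f(r) = r - \log r - 1$:

$$f'(r) = 1 - \frac{1}{r}, \qquad f(r) \ge f(1) = 0 \quad \text{for all } r > 0,$$

with equality only at $r = 1$; taking $r = p(x)/q(x)$, a per-sample kl3 (or kl2) term is zero exactly when $p(x) = q(x)$.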


gauravpandeyamu commented Feb 4, 2025

Also worth noting that $E_{\pi_t} \left[\frac{\pi_{ref}}{\pi_\theta} - \log \frac{\pi_{ref}}{\pi_\theta} - 1\right]$ is a valid divergence, and the kl3 estimator and its gradient are unbiased estimators of this divergence and its gradient, respectively.
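Since $\pi_t$ does not depend on $\theta$, the gradient and the expectation can be exchanged:

$$\nabla_\theta\, \mathbb{E}_{\pi_t}\left[\frac{\pi_{ref}}{\pi_\theta} - \log \frac{\pi_{ref}}{\pi_\theta} - 1\right] = \mathbb{E}_{\pi_t}\left[\nabla_\theta\left(\frac{\pi_{ref}}{\pi_\theta} - \log \frac{\pi_{ref}}{\pi_\theta} - 1\right)\right],$$

so the sample average of the kl3 terms, and of their gradients, is unbiased for this divergence and its gradient respectively.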

@tristandeleu

That makes sense. In that case, I would find KL2 a more "principled" choice for optimization, since its gradient wrt. $\theta$ is an unbiased estimator of the gradient of the KL divergence (provided that it still has low variance, which seems to be validated by @vwxyzjn's experiments). But KL2 would still be a biased estimator of the KL divergence (to report metrics), so maybe optimizing KL2 (kl = kl2) and monitoring KL3 is the way to go?

But your comment on it being a valid divergence also makes a lot of sense, and maybe what the success of KL3 shows is that we should go beyond just the KL divergence.
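A minimal sketch of that "optimize kl2, monitor kl3" setup; the tensor names and the masked_mean helper are illustrative, not the ones in open_instruct:

```python
import torch

def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Mean over positions where mask is True (e.g. non-padding tokens).
    return (values * mask).sum() / mask.sum().clamp(min=1)

def kl_penalty_and_monitor(new_logprobs: torch.Tensor, ref_logprobs: torch.Tensor,
                           valid_mask: torch.Tensor, beta: float = 0.03):
    logprob_diff = new_logprobs - ref_logprobs          # per-token log(pi_theta / pi_ref)
    kl2 = 0.5 * logprob_diff ** 2                       # optimized: gradient is unbiased for grad KL
    kl3 = (-logprob_diff).exp() - 1 + logprob_diff      # monitored: non-negative, low-variance KL estimate
    kl_loss = beta * masked_mean(kl2, valid_mask)       # add this to the policy loss
    kl_monitor = masked_mean(kl3, valid_mask).detach()  # log this as the KL metric
    return kl_loss, kl_monitor
```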

@gauravpandeyamu

I agree with KL2 being a more principled choice for optimization.
There is work that explores f-divergences in the PPO objective: https://arxiv.org/pdf/2309.16240
