
Why does the loss start at 0 when I train GRPO, and then possibly increase? #239

Closed

hellen9527 opened this issue Feb 8, 2025 · 14 comments

@hellen9527

hellen9527 commented Feb 8, 2025

I am using the distill-1.5b model. Since I only have 4 L20 GPUs, I modified some parameters and am training GRPO on the NuminaMath-TIR dataset. However, I noticed that the loss stays at 0, and I'm not sure where the configuration went wrong. I made sure the software versions match those in setup.py, and I also updated TRL and transformers to the latest main branch. The specific logs and training configuration are as follows. I would like to know whether this is normal and how to fix it.

train config:

# Model arguments
model_name_or_path: /home/base-model/deepseek-r1-distill-qwen-1.5b
model_revision: main
torch_dtype: bfloat16

# Num processes is one less than the GPU count because vLLM uses 1 GPU
num_processes: 3

# GRPO trainer config
gradient_accumulation_steps: 2
num_generations: 3
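
For reference, a sketch of the batch arithmetic implied by this config (my own arithmetic, not from the repo; it matches the "Total train batch size = 6" line in the log below):

# One of the 4 L20 GPUs is reserved for vLLM generation, leaving 3 training processes.
num_processes = 3
per_device_train_batch_size = 1
gradient_accumulation_steps = 2
num_generations = 3

total_batch = num_processes * per_device_train_batch_size * gradient_accumulation_steps
print(total_batch)                     # 6 completions per optimization step
print(total_batch // num_generations)  # 2 distinct prompts, each sampled 3 times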

train log

[INFO|trainer.py:2348] 2025-02-08 12:02:29,782 >> ***** Running training *****
[INFO|trainer.py:2349] 2025-02-08 12:02:29,782 >>   Num examples = 72,441
[INFO|trainer.py:2350] 2025-02-08 12:02:29,782 >>   Num Epochs = 1
[INFO|trainer.py:2351] 2025-02-08 12:02:29,782 >>   Instantaneous batch size per device = 1
[INFO|trainer.py:2354] 2025-02-08 12:02:29,782 >>   Total train batch size (w. parallel, distributed & accumulation) = 6
[INFO|trainer.py:2355] 2025-02-08 12:02:29,782 >>   Gradient Accumulation steps = 2
[INFO|trainer.py:2356] 2025-02-08 12:02:29,782 >>   Total optimization steps = 36,220
[INFO|trainer.py:2357] 2025-02-08 12:02:29,783 >>   Number of trainable parameters = 1,777,088,000
{'loss': 0.0, 'grad_norm': 0.72000175680703, 'learning_rate': 2.760905577029266e-08, 'rewards/accuracy_reward': 0.26666667461395266, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.6777778208255768, 'rewards/cosine_scaled_reward': -0.022902203630656003, 'reward': 0.921542277932167, 'reward_std': 0.871876309812069, 'completion_length': 876.4000122070313, 'kl': 0.00035610198974609373, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': 0.8210723493263515, 'learning_rate': 5.521811154058532e-08, 'rewards/accuracy_reward': 0.10000000298023223, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.6333333641290665, 'rewards/cosine_scaled_reward': -0.23128306418657302, 'reward': 0.5020502872765065, 'reward_std': 0.43509662076830863, 'completion_length': 884.033349609375, 'kl': 0.0006114959716796875, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': 0.6075981772711617, 'learning_rate': 8.282716731087798e-08, 'rewards/accuracy_reward': 0.1666666716337204, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.5555555850267411, 'rewards/cosine_scaled_reward': -0.16871370139997452, 'reward': 0.5535085469484329, 'reward_std': 0.6925141368061304, 'completion_length': 886.1666809082031, 'kl': 0.0005586624145507812, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': 0.7033610775329348, 'learning_rate': 1.1043622308117064e-07, 'rewards/accuracy_reward': 0.1666666716337204, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.6888889163732529, 'rewards/cosine_scaled_reward': -0.17193117612041534, 'reward': 0.6836243975907564, 'reward_std': 0.7369554199278354, 'completion_length': 892.0000122070312, 'kl': 0.00048828125, 'epoch': 0.0}
...
{'loss': 0.0001, 'grad_norm': 0.6114522070289464, 'learning_rate': 1.049144119271121e-06, 'rewards/accuracy_reward': 0.3000000089406967, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.7333333641290665, 'rewards/cosine_scaled_reward': -0.05265774726867676, 'reward': 0.9806756511330604, 'reward_std': 0.8146779596805572, 'completion_length': 926.8666748046875, 'kl': 0.001399993896484375, 'epoch': 0.01}
{'loss': 0.0001, 'grad_norm': 0.6375849273871735, 'learning_rate': 1.0767531750414136e-06, 'rewards/accuracy_reward': 0.1666666716337204, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.7111111462116242, 'rewards/cosine_scaled_reward': -0.14114616215229034, 'reward': 0.736631666123867, 'reward_std': 0.7692775622010231, 'completion_length': 937.2000122070312, 'kl': 0.001470184326171875, 'epoch': 0.01}
{'loss': 0.0001, 'grad_norm': 0.7375909133054507, 'learning_rate': 1.1043622308117063e-06, 'rewards/accuracy_reward': 0.36666667759418486, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.844444477558136, 'rewards/cosine_scaled_reward': 0.036993000144138935, 'reward': 1.2481041848659515, 'reward_std': 1.0289975732564927, 'completion_length': 829.4000122070313, 'kl': 0.0028339385986328124, 'epoch': 0.01}
@asirgogogo

@hellen9527 #235

Try this?

@tenacioustommy

I have the same problem.

@HarveyYi

HarveyYi commented Feb 8, 2025

I have the same problem. 😭

@hellen9527
Author

@hellen9527 #235

Try this?

That might not be the issue here. In your case the format reward is 0, but your loss isn't 0. What I find even stranger is that after 400 training steps the loss can still go up, so it feels like my run is optimizing toward an increasing loss. I'm not sure whether it's the loss function, e.g. whether someone forgot to subtract from 1 or something. These are my latest logs; I haven't stopped the run yet and want to watch a bit longer to see what happens.

{'loss': 0.0007, 'grad_norm': 1.660251103195432, 'learning_rate': 2.2915516289342906e-06, 'rewards/accuracy_reward': 0.33333334028720857, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.7444444686174393, 'rewards/cosine_scaled_reward': 0.08663841746747494, 'reward': 1.164416253566742, 'reward_std': 0.9080875471234322, 'completion_length': 834.3333465576172, 'kl': 0.0183746337890625, 'epoch': 0.01}
{'loss': 0.0019, 'grad_norm': 0.7700014026503217, 'learning_rate': 2.3191606847045835e-06, 'rewards/accuracy_reward': 0.43333334624767306, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.8777778089046478, 'rewards/cosine_scaled_reward': 0.17203281931579112, 'reward': 1.483143937587738, 'reward_std': 1.190664093196392, 'completion_length': 743.1666870117188, 'kl': 0.04857177734375, 'epoch': 0.01}
{'loss': 0.0005, 'grad_norm': 0.8288001079549833, 'learning_rate': 2.346769740474876e-06, 'rewards/accuracy_reward': 0.4333333373069763, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.9333333492279052, 'rewards/cosine_scaled_reward': 0.18143641203641891, 'reward': 1.548103114962578, 'reward_std': 0.45796665027737615, 'completion_length': 790.233349609375, 'kl': 0.0115997314453125, 'epoch': 0.01}
{'loss': 0.0005, 'grad_norm': 0.7728047380844649, 'learning_rate': 2.374378796245169e-06, 'rewards/accuracy_reward': 0.5333333492279053, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.866666704416275, 'rewards/cosine_scaled_reward': 0.2526581108570099, 'reward': 1.6526582002639771, 'reward_std': 1.2655969619750977, 'completion_length': 854.6666870117188, 'kl': 0.0123199462890625, 'epoch': 0.01}

@xx-Jiangwen

same problem

@hellen9527 hellen9527 changed the title When I run the GRPO demo, I find that loss is always 0!!! Why does the loss start at 0 when I train GRPO, and then possibly increase? Feb 8, 2025
@hellen9527
Author

hellen9527 commented Feb 8, 2025

I tried lowering the version of math-verify to 0.3.3 (it was originally 0.5.2), and only kept the format and accuracy rewards, but it still doesn't work; the loss is still 0 initially.

{'loss': 0.0, 'grad_norm': 0.6244989316226042, 'learning_rate': 5.521811154058532e-08, 'rewards/accuracy_reward': 0.2000000037252903, 'rewards/format_reward': 0.0, 'reward': 0.2000000037252903, 'reward_std': 0.28750002533197405, 'completion_length': 860.433349609375, 'kl': 0.00044727325439453125, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': 0.484736920892034, 'learning_rate': 1.1043622308117064e-07, 'rewards/accuracy_reward': 0.17500000447034836, 'rewards/format_reward': 0.0, 'reward': 0.17500000447034836, 'reward_std': 0.30477996468544005, 'completion_length': 925.0916809082031, 'kl': 0.000531768798828125, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': 0.7098745400325474, 'learning_rate': 1.6565433462175596e-07, 'rewards/accuracy_reward': 0.1833333380520344, 'rewards/format_reward': 0.0, 'reward': 0.1833333380520344, 'reward_std': 0.24354272335767746, 'completion_length': 883.6416839599609, 'kl': 0.0005475044250488281, 'epoch': 0.0}
...
{'loss': 0.0001, 'grad_norm': 0.33847296259138065, 'learning_rate': 1.049144119271121e-06, 'rewards/accuracy_reward': 0.3416666738688946, 'rewards/format_reward': 0.0, 'reward': 0.3416666738688946, 'reward_std': 0.3187273934483528, 'completion_length': 872.9750152587891, 'kl': 0.0026445388793945312, 'epoch': 0.01}
{'loss': 0.0001, 'grad_norm': 0.6942921374667554, 'learning_rate': 1.1043622308117063e-06, 'rewards/accuracy_reward': 0.3250000081956387, 'rewards/format_reward': 0.0, 'reward': 0.3250000081956387, 'reward_std': 0.32954234778881075, 'completion_length': 831.808349609375, 'kl': 0.0024944305419921874, 'epoch': 0.01}
{'loss': 0.0002, 'grad_norm': 0.8375045623392503, 'learning_rate': 1.1595803423522917e-06, 'rewards/accuracy_reward': 0.31666667461395265, 'rewards/format_reward': 0.0, 'reward': 0.31666667461395265, 'reward_std': 0.2693626120686531, 'completion_length': 812.1083557128907, 'kl': 0.006044578552246094, 'epoch': 0.01}

@saidineshpola

With the exception of the format reward, the other rewards are improving:

 {'loss': 0.0, 'grad_norm': 0.007852829992771149, 'learning_rate': 3.448275862068966e-06, 'completion_length': 903.7671875, 'rewards/accuracy_reward': 0.191015625, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.6334635481238365, 'rewards/cosine_scaled_reward': -0.1339177769375965, 'reward': 0.6905613947659731, 'reward_std': 0.5082616737112403, 'kl': 9.812116622924805e-05, 'epoch': 0.02}

{'loss': 0.0001, 'grad_norm': 0.006408413872122765, 'learning_rate': 6.896551724137932e-06, 'completion_length': 899.99296875, 'rewards/accuracy_reward': 0.2171875, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.6636718828231096, 'rewards/cosine_scaled_reward': -0.10014793432783335, 'reward': 0.7807114448398351, 'reward_std': 0.4804459346458316, 'kl': 0.0012892663478851318, 'epoch': 0.04}

{'loss': 0.0005, 'grad_norm': 0.007318104617297649, 'learning_rate': 1.0344827586206898e-05, 'completion_length': 854.7265625, 'rewards/accuracy_reward': 0.32265625, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.7664062581956387, 'rewards/cosine_scaled_reward': 0.0210849943687208, 'reward': 1.1101475007832051, 'reward_std': 0.4497886072844267, 'kl': 0.01324615478515625, 'epoch': 0.05}

{'loss': 0.0009, 'grad_norm': 0.005526963155716658, 'learning_rate': 1.3793103448275863e-05, 'completion_length': 791.9078125, 'rewards/accuracy_reward': 0.4265625, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.9080729261040688, 'rewards/cosine_scaled_reward': 0.14467546985251828, 'reward': 1.4793108977377414, 'reward_std': 0.4111370643600821, 'kl': 0.0230438232421875, 'epoch': 0.07}

{'loss': 0.0013, 'grad_norm': 0.0053334906697273254, 'learning_rate': 1.7241379310344828e-05, 'completion_length': 789.725, 'rewards/accuracy_reward': 0.43671875, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.9627604238688946, 'rewards/cosine_scaled_reward': 0.171499810856767, 'reward': 1.5709789715707303, 'reward_std': 0.38927115853875877, 'kl': 0.03206634521484375, 'epoch': 0.09}

9%|▉ | 25/283 [8:04:04<82:58:09, 1157.71s/it]

@ZhenweiAn

I used the modified format reward from @hellen9527 #235, but the loss is still strange.
{'loss': 0.0, 'grad_norm': 1.673892617225647, 'learning_rate': 1.0714285714285716e-06, 'rewards/accuracy_reward': 0.6413265191018581, 'rewards/format_reward': 0.0005102040711790323, 'rewards/reasoning_steps_reward': 0.28860543891787527, 'rewards/cosine_scaled_reward': 0.3494693774729967, 'reward': 1.279911534488201, 'reward_std': 0.8032788023352623, 'completion_length': 750.2198791503906, 'kl': 0.00041866302490234375, 'epoch': 0.04}

{'loss': 0.0, 'grad_norm': 0.9771971106529236, 'learning_rate': 2.142857142857143e-06, 'rewards/accuracy_reward': 0.6280612092465162, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.29370747953653337, 'rewards/cosine_scaled_reward': 0.3368528074584901, 'reward': 1.2586214922368526, 'reward_std': 0.7501622267067433, 'completion_length': 806.3540603637696, 'kl': 0.0010107040405273437, 'epoch': 0.07}

{'loss': 0.0001, 'grad_norm': 0.4975500702857971, 'learning_rate': 2.999485987463336e-06, 'rewards/accuracy_reward': 0.7127550840377808, 'rewards/format_reward': 0.0010204081423580646, 'rewards/reasoning_steps_reward': 0.3047619042918086, 'rewards/cosine_scaled_reward': 0.4018472107127309, 'reward': 1.420384594798088, 'reward_std': 0.7048569574952126, 'completion_length': 756.1489646911621, 'kl': 0.003369712829589844, 'epoch': 0.11}

{'loss': 0.2135, 'grad_norm': 1.6367053985595703, 'learning_rate': 2.981532510892707e-06, 'rewards/accuracy_reward': 0.7147959008812904, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.3149659845978022, 'rewards/cosine_scaled_reward': 0.4177613776177168, 'reward': 1.4475232884287834, 'reward_std': 0.6804138027131558, 'completion_length': 805.1270263671875, 'kl': 5.331846427917481, 'epoch': 0.15}

{'loss': 0.0003, 'grad_norm': 0.5887525677680969, 'learning_rate': 2.9382296023022897e-06, 'rewards/accuracy_reward': 0.7341836579144001, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.3403061218559742, 'rewards/cosine_scaled_reward': 0.4471289239823818, 'reward': 1.521618703007698, 'reward_std': 0.6430120587348938, 'completion_length': 785.0775329589844, 'kl': 0.006349372863769531, 'epoch': 0.19}

@thorinf

thorinf commented Feb 8, 2025

per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)

https://github.com/huggingface/trl/blob/84d73fd00b188721e28bd9a18ad38f100114dbda/trl/trainer/grpo_trainer.py#L627

If you are using the GRPO trainer, the old policy is in effect updated every step, which means you just use a detached version of the current policy. The resulting probability ratio is therefore always 1. By definition of GRPO, the advantage is standardised, so its expectation is 0. The expectation of the probability ratio multiplied by the advantage is therefore also zero.

Although the loss is zero, there are still gradients in this case.

The increase may be from KL increasing as you move away from the original distribution.
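
A minimal PyTorch sketch of this point (my own illustration, not the TRL source): the forward value of the ratio is exactly 1, so the loss term reduces to the mean advantage (≈ 0), yet the gradient with respect to the log-probabilities is still driven by the advantage.

import torch

# Toy values; group-normalized advantages sum to 0.
per_token_logps = torch.tensor([-1.2, -0.7, -2.3], requires_grad=True)
advantages = torch.tensor([0.5, -0.5, 0.0])

# pi_theta / pi_theta_old, with pi_theta_old being a detached copy of the current policy
ratio = torch.exp(per_token_logps - per_token_logps.detach())
print(ratio)                 # all ones -> the value contributes nothing to the loss

loss = -(ratio * advantages).mean()
loss.backward()
print(loss.item())           # ~0.0: the logged loss is (almost) zero
print(per_token_logps.grad)  # tensor([-0.1667, 0.1667, 0.0000]): gradients are still non-zero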

@hellen9527
Author

per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)

https://github.com/huggingface/trl/blob/84d73fd00b188721e28bd9a18ad38f100114dbda/trl/trainer/grpo_trainer.py#L627

If you are using the GRPO trainer, the old policy is in effect updated every step, which means you just use a detached version of the current policy. The resulting probability ratio is therefore always 1. By definition of GRPO, the advantage is standardised, so its expectation is 0. The expectation of the probability ratio multiplied by the advantage is therefore also zero.

Although the loss is zero, there are still gradients in this case.

The increase may be from KL increasing as you move away from the original distribution.

How should I modify it? I saw an explanation saying that the purpose of keeping the ratio at 1 is to ensure that the loss reduction is driven entirely by the advantage. I'm using the latest main branch of trl, installed via python setup.py install. What should I change?

@hellen9527
Author

After switching to the latest training script, I noticed that the initial loss is still 0, but both format_reward and accuracy_reward are slowly increasing. It seems these two are behaving normally, so why does the loss still look abnormal? Or does the logged loss not matter at all?

{'loss': 0.0, 'grad_norm': 0.9503731084601761, 'learning_rate': 5.521811154058532e-07, 'rewards/accuracy_reward': 0.275, 'rewards/format_reward': 0.01875, 'rewards/reasoning_steps_reward': 0.6427083551883698, 'rewards/cosine_scaled_reward': -0.03994722058996558, 'reward': 0.896511122584343, 'reward_std': 0.8323002576828002, 'completion_length': 849.65625, 'kl': 8.575916290283203e-05, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': 0.4910426002806111, 'learning_rate': 6.626173384870238e-07, 'rewards/accuracy_reward': 0.1625, 'rewards/format_reward': 0.01875, 'rewards/reasoning_steps_reward': 0.6739583507180213, 'rewards/cosine_scaled_reward': -0.14418444326147437, 'reward': 0.7110238954424858, 'reward_std': 0.7402671471238136, 'completion_length': 887.85, 'kl': 0.00010432004928588867, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': 0.37581753419329067, 'learning_rate': 7.730535615681944e-07, 'rewards/accuracy_reward': 0.19375, 'rewards/format_reward': 0.0, 'rewards/reasoning_steps_reward': 0.6385416761040688, 'rewards/cosine_scaled_reward': -0.10490049961954355, 'reward': 0.7273911885917187, 'reward_std': 0.6636361941695214, 'completion_length': 897.253125, 'kl': 0.00010716915130615234, 'epoch': 0.0}
...
{'loss': 0.0002, 'grad_norm': 0.320330599760996, 'learning_rate': 1.7669795692987302e-06, 'rewards/accuracy_reward': 0.3625, 'rewards/format_reward': 0.003125, 'rewards/reasoning_steps_reward': 0.8802083492279053, 'rewards/cosine_scaled_reward': 0.0801055665127933, 'reward': 1.3259388893842696, 'reward_std': 0.724829213321209, 'completion_length': 816.81875, 'kl': 0.00401153564453125, 'epoch': 0.01}
{'loss': 0.0002, 'grad_norm': 0.28178907825908484, 'learning_rate': 1.8774157923799008e-06, 'rewards/accuracy_reward': 0.4125, 'rewards/format_reward': 0.003125, 'rewards/reasoning_steps_reward': 0.8916666805744171, 'rewards/cosine_scaled_reward': 0.14045850289985537, 'reward': 1.4477501690387726, 'reward_std': 0.7019044987857341, 'completion_length': 811.49375, 'kl': 0.005374908447265625, 'epoch': 0.01}

@qgallouedec
Member

qgallouedec commented Feb 9, 2025

It is completely normal for the loss to start at zero and then increase. Here’s why

The first thing is to understand the GRPO objective, which is formulated as follows:

$$\mathcal{J}_{\text{GRPO}}(\theta) =\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\left[\min \left(\frac{\pi_\theta(o_{i,t} | q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} | q, o_{i,< t})} \hat{A}_{i,t}, \text{clip}\left(\frac{\pi_\theta(o_{i,t} | q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} | q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon\right) \hat{A}_{i,t}\right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right]\right].$$

where:

  • $G$ is the number of generations per prompt.
  • $o_i$ is the $i$-th generation of the prompt, and $|o_i|$ is the number of tokens in $o_i$.
  • $q$ is the prompt.
  • $\pi_\theta$ is the policy model.
  • $\pi_{\theta_{\text{old}}}$ is the policy model before the update.
  • $\pi_{\text{ref}}$ is the reference policy.
  • $\hat{A}_{i,t}$ is the advantage estimate for the $t$-th token in the $i$-th generation (see below).
  • $\epsilon$ and $\beta$ are hyperparameters.

Note

Here, what interests us is the value of the loss itself rather than its gradient, so the parts of the terms that only matter for the gradient can be ignored.

To simplify, let’s assume we only perform one exploration step per iteration (which is the standard implementation of GRPO). Consequently, $\pi_{\theta_{\text{old}}} = \pi_{\theta}$. The objective naturally simplifies to:

$$\mathcal{J}_{\text{GRPO}}(\theta) =\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\left[\min \left(\frac{\pi_\theta(o_{i,t} | q, o_{i,< t})}{\pi_{\theta}(o_{i,t} | q, o_{i,< t})} \hat{A}_{i,t}, \text{clip}\left(\frac{\pi_\theta(o_{i,t} | q, o_{i,< t})}{\pi_{\theta}(o_{i,t} | q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon\right) \hat{A}_{i,t}\right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right]\right]$$

$$=\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\left[\min \left(\hat{A}_{i,t}, \text{clip}\left(1, 1 - \epsilon, 1 + \epsilon\right) \hat{A}_{i,t}\right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right]\right]$$

$$=\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\left[\min \left(\hat{A}_{i,t}, \hat{A}_{i,t}\right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right]\right]$$

$$=\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\left[\hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right]\right]$$

$$=\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\hat{A}_{i,t} - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right].$$

Remember that the advantage does not depend on $t$, contrary to what the notation might suggest, so:

$$\frac{1}{|o_i|} \sum_{t=1}^{|o_i|}\hat{A}_{i,t}=\frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\hat{A}_{i}=\hat{A}_{i}$$

Moreover, $\hat{A}_{i}$ is normalized within each group, which means:

$$\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\hat{A}_{i,t} = \frac{1}{G} \sum_{i=1}^G \hat{A}_{i} = 0$$

Therefore:

$$\mathcal{J}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right].$$

In other words, in absolute terms, the loss is equal to the average KL divergence multiplied by $\beta$.

Now, since the reference policy and the policy are initially equal, this is why the loss starts at zero. Training causes the policy to diverge from the initial reference policy, which is why the loss increases.

Finally, this is entirely consistent with the equations. 🤗
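
A small numeric sketch of this conclusion (my own illustration; β = 0.04 is an assumption, matching TRL's default at the time): with a single exploration step the ratio term averages to zero, so the logged loss is just β times the mean per-token KL, which is 0 at step 0 and grows as the policy drifts from the reference.

import torch

beta = 0.04                                            # assumed KL coefficient (TRL default)
G, T = 4, 6                                            # generations per prompt, tokens each

rewards = torch.tensor([0.2, 0.9, 0.0, 0.5])
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
per_token_adv = advantages.unsqueeze(1).expand(G, T)   # same advantage repeated over tokens

ratio_term = (1.0 * per_token_adv).mean()              # ratio == 1, so this is the mean advantage == 0

per_token_kl = torch.zeros(G, T)                       # step 0: policy == reference policy
loss = -(ratio_term - beta * per_token_kl.mean())
print(loss.item())                                     # ~0.0 at the start of training

per_token_kl = torch.full((G, T), 5.33)                # later, as the policy drifts away
loss = -(ratio_term - beta * per_token_kl.mean())
print(loss.item())                                     # ~0.213

Under that β assumption this also lines up with the odd-looking step in the logs above, where the loss jumps to 0.2135 exactly when kl jumps to ≈ 5.33.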

@hellen9527
Author

hellen9527 commented Feb 10, 2025

It is completely normal for the loss to start at zero and then increase. Here’s why

...

In other words, in absolute terms, the loss is equal to the average KL divergence multiplied by β.

Now, since the reference policy and the policy are initially equal, this is why the loss starts at zero. Training causes the policy to diverge from the initial reference policy, which is why the loss increases.

Finally, this is entirely consistent with the equations. 🤗

@qgallouedec Thank you so much! I was even wondering if there was something wrong with my environment setup or code version. If this is normal, that's great! It means I can focus more on the actual training results.

@Jefferyy-Peng

Jefferyy-Peng commented Feb 16, 2025

@qgallouedec, thanks for your clear explanation. However, I still wonder why GRPO performs only one exploration step per iteration, which is different from PPO. Could you elaborate on that? Furthermore, referring to Algorithm 1 in the DeepSeekMath paper, which loop exactly do you mean by "exploration step"? Is it the outermost iteration, the step loop, or the inner GRPO iteration loop?

[Image: Algorithm 1 from the DeepSeekMath paper]
