
 import numpy as np
 import torch
+from collections import defaultdict

 import verl.utils.torch_functional as verl_F

@@ -106,6 +107,54 @@ def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torc
     return advantages, returns


+# NOTE(sgm): this implementation only considers outcome supervision, where the reward is a scalar.
+def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
+                                   eos_mask: torch.Tensor,
+                                   index: torch.Tensor,
+                                   epsilon: float = 1e-6):
+    """
+    Compute advantage for GRPO, operating only on the outcome reward
+    (a single scalar reward for each response).
+    Args:
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        eos_mask: `(torch.Tensor)`
+            shape: (bs, response_length)
+        index:
+            per-sample prompt index; responses sharing an index are normalized as one group
+        epsilon: `(float)`
+            small constant added to the std for numerical stability
+
+    Returns:
+        advantages: `(torch.Tensor)`
+            shape: (bs, response_length)
+        returns: `(torch.Tensor)`
+            shape: (bs, response_length)
+    """
+    response_length = token_level_rewards.shape[-1]
+    non_zero_mask = (token_level_rewards != 0)
+    scores = (token_level_rewards * non_zero_mask).sum(dim=-1)
+
+    id2score = defaultdict(list)
+    id2mean = {}
+    id2std = {}
+
+    with torch.no_grad():
+        bsz = scores.shape[0]
+        for i in range(bsz):
+            id2score[index[i]].append(scores[i])
+        for idx in id2score:
+            if len(id2score[idx]) == 1:
+                # a single response for this prompt: use mean 0 / std 1 so the score passes through unchanged
+                id2mean[idx] = torch.tensor(0.0)
+                id2std[idx] = torch.tensor(1.0)
+            elif len(id2score[idx]) > 1:
+                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
+                id2std[idx] = torch.std(torch.tensor(id2score[idx]))
+            else:
+                raise ValueError(f"no score in prompt index: {idx}")
+        for i in range(bsz):
+            scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
+        # broadcast the normalized scalar score over the response tokens
+        scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
+
+    return scores, scores
+
+
 def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
     kl = old_log_prob - ref_log_prob
     return token_level_scores - kl * kl_ratio
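
For reference, here is a minimal usage sketch of the new GRPO advantage. The batch values, shapes, and group ids below are invented for illustration, and only compute_grpo_outcome_advantage from the diff above is assumed to be in scope; note that the implementation only needs index[i] to be usable as a dict key (e.g. a per-prompt uid), so a plain list is passed here.

import torch

# hypothetical toy batch: 4 responses to 2 prompts, 5 response tokens each
token_level_rewards = torch.zeros(4, 5)
token_level_rewards[:, -1] = torch.tensor([1.0, 0.0, 1.0, 1.0])  # scalar outcome reward on the last token
eos_mask = torch.ones(4, 5)                                      # all response tokens valid in this toy case
index = ["prompt_a", "prompt_a", "prompt_b", "prompt_b"]         # group id per sample

advantages, returns = compute_grpo_outcome_advantage(token_level_rewards, eos_mask, index)
# each scalar reward is normalized within its prompt group, (r - mean) / (std + eps),
# then broadcast over the response tokens: advantages has shape (4, 5) and its first
# column is roughly [0.71, -0.71, 0.0, 0.0] for this toy batch
print(advantages[:, 0])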

@@ -210,6 +259,14 @@ def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_pe
     if kl_penalty == "mse":
         return 0.5 * (logprob - ref_logprob).square()

+    # J. Schulman. Approximating KL divergence, 2020.
+    # URL http://joschu.net/blog/kl-approx.html.
+    if kl_penalty == 'low_var_kl':
+        kl = ref_logprob - logprob
+        ratio = torch.exp(kl)
+        kld = (ratio - kl - 1).contiguous()
+        return torch.clamp(kld, min=-10, max=10)
+
     if kl_penalty == "full":
         # so, here logprob and ref_logprob should contain the logits for every token in vocabulary
         raise NotImplementedError
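
The 'low_var_kl' branch is the k3 estimator from the blog post cited above: with r = p_ref(x) / pi(x) for samples x drawn from the policy pi, k3 = (r - 1) - log r is non-negative and unbiased for KL(pi || p_ref), with much lower variance than the naive log-ratio. A quick numeric sketch (the Normal distributions and sample count are made up for illustration):

import torch

torch.manual_seed(0)
pi = torch.distributions.Normal(0.0, 1.0)    # stand-in for the policy
ref = torch.distributions.Normal(0.1, 1.0)   # stand-in for the reference model

x = pi.sample((100_000,))
logprob, ref_logprob = pi.log_prob(x), ref.log_prob(x)

k1 = logprob - ref_logprob       # naive estimator: unbiased but noisy, negative on many samples
kl = ref_logprob - logprob
k3 = torch.exp(kl) - kl - 1      # the 'low_var_kl' estimator above, before the clamp

true_kl = torch.distributions.kl_divergence(pi, ref)       # analytic value: 0.005
print(true_kl.item(), k1.mean().item(), k3.mean().item())  # both estimators average to ~0.005
print(k1.std().item(), k3.std().item())                    # k3's per-sample std is far smaller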