Add optional normalization (cont.) #98

Merged · 7 commits · Nov 20, 2022
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
@@ -2,9 +2,9 @@ name: Build

 on:
   push:
-    branches: [ master ]
+    branches: [ main ]
   pull_request:
-    branches: [ master ]
+    branches: [ main ]

 jobs:
   build:
6 changes: 4 additions & 2 deletions configs/ppo_config.yml
@@ -35,8 +35,10 @@ method:
   cliprange: 0.2 # clip range
   cliprange_value: 0.2 # clip range
   vf_coef: 2.3 # value term weight
-  scale_reward: True
-  clip_reward: 10
+  scale_reward: "running" # False | "ref" | "running" estimate against which to scale rewards
+  ref_mean: null
+  ref_std: null # rescale rewards with this deviation
+  cliprange_reward: 10
   gen_kwargs:
     max_length: 48 # LM max sample gen length
     min_length: 48 # LM min sample gen length
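For context, the new keys replace the boolean scale_reward with three modes: False disables scaling, "running" divides rewards by the running standard deviation tracked across rollouts, and "ref" divides by a fixed reference deviation; when ref_mean and ref_std are left null they are filled from the first rollout's statistics (see the ppo_orchestrator.py change below). A hypothetical fragment using the "ref" mode, with placeholder values rather than anything from this PR, might look like:

```yaml
method:
  scale_reward: "ref"   # divide rewards by a fixed reference deviation
  ref_mean: 0.0         # placeholder reference mean, not a recommended value
  ref_std: 1.0          # placeholder reference deviation, not a recommended value
  cliprange_reward: 10  # clip (scaled) rewards to [-10, 10]
```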
4 changes: 4 additions & 0 deletions configs/test_config.yml
@@ -35,6 +35,10 @@ method:
   cliprange: 0.2 # clip range
   cliprange_value: 0.2 # clip range
   vf_coef: 1.0 # value term weight
+  scale_reward: "running" # False|"ref"|"running" estimate against which to scale rewards
+  cliprange_reward: 10
+  ref_mean: null
+  ref_std: null
   gen_kwargs:
     max_length: 48 # LM max sample gen length
     min_length: 48 # LM min sample gen length
20 changes: 20 additions & 0 deletions tests/test_ppo.py
@@ -1,6 +1,7 @@
 import unittest
 from trlx.data.configs import TRLConfig
 from trlx.model.nn.ppo_models import GPTHydraHeadWithValueModel
+from trlx.utils.modeling import RunningMoments
 from transformers import AutoTokenizer
 import torch

@@ -44,3 +45,22 @@ def test_forward(self):
         logits_diff = torch.sum(unfrozen_logits - frozen_logits).item()
         self.assertEqual(hs_diff, 0)
         self.assertEqual(logits_diff, 0)
+
+class TestStatistics(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.m = RunningMoments()
+        cls.a1 = torch.arange(100, dtype=float)
+        cls.a2 = torch.ones(100, dtype=float)
+        cls.a3 = torch.exp(torch.arange(10, dtype=float))
+        cls.a4 = torch.tensor([-10, -1, 0, 1, 10], dtype=float)
+
+    def test_running_moments(self):
+        assert torch.isclose(self.m.update(self.a1)[1], self.a1.std(unbiased=True), atol=1e-6)
+        assert torch.isclose(self.m.update(self.a2)[1], self.a2.std(unbiased=True), atol=1e-6)
+        assert torch.isclose(self.m.update(self.a3)[1], self.a3.std(unbiased=True), atol=1e-6)
+        assert torch.isclose(self.m.update(self.a4)[1], self.a4.std(unbiased=True), atol=1e-6)
+
+        a = torch.hstack((self.a1, self.a2, self.a3, self.a4))
+        assert torch.isclose(self.m.mean, a.mean(), atol=1e-6)
+        assert torch.isclose(self.m.std, a.std(unbiased=True), atol=1e-6)
6 changes: 4 additions & 2 deletions trlx/model/nn/ppo_models.py
@@ -111,8 +111,10 @@ class PPOConfig(MethodConfig):
     cliprange: float
     cliprange_value: float
     vf_coef: float
-    scale_reward: bool
-    clip_reward: float
+    scale_reward: str
+    ref_mean: Optional[float]
+    ref_std: Optional[float]
+    cliprange_reward: float
     gen_kwargs: dict

     def get_advantages_and_returns(
12 changes: 7 additions & 5 deletions trlx/orchestrator/ppo_orchestrator.py
@@ -47,8 +47,8 @@ def __init__(
         self.rl_model.metric_fn = metric_fn

         self.running = RunningMoments()
-        self.ref_mean = None
-        self.ref_std = None
+        self.ref_mean = self.rl_model.config.method.ref_mean
+        self.ref_std = self.rl_model.config.method.ref_std

     def score(self, samples):
         """
@@ -84,19 +84,21 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
         scores = torch.as_tensor(self.score(texts), device=samples.device)
         stats["exp_score_time"] = time() - exp_score_time

+        # store statistics of the initial rollout as reference
         if self.ref_mean is None:
             self.ref_mean, self.ref_std = scores.mean(), scores.std()
         all_scores_mean, all_scores_std = self.running.update(scores)

         stats["exp_scores_mean"] = all_scores_mean
         stats["exp_scores_std"] = all_scores_std
         stats["running_mean"] = self.running.mean
         stats["running_std"] = self.running.std

-        if self.rl_model.config.method.scale_reward:
+        if self.rl_model.config.method.scale_reward == "running":
             scores /= self.running.std
+        elif self.rl_model.config.method.scale_reward == "ref":
+            scores /= self.ref_std

-        clip_reward = self.rl_model.config.method.clip_reward
+        clip_reward = self.rl_model.config.method.cliprange_reward
         if clip_reward:
             scores = torch.clip(scores, -clip_reward, clip_reward)

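Taken out of the orchestrator, the post-scoring path above amounts to the following standalone sketch. The function name and signature are invented for illustration; the real logic lives inline in make_experience and reads the same values from self.rl_model.config.method and self.running:

```python
from typing import Optional, Union

import torch


def scale_and_clip_rewards(
    scores: torch.Tensor,
    scale_reward: Union[bool, str],   # False | "ref" | "running"
    running_std: torch.Tensor,
    ref_std: Optional[torch.Tensor],
    cliprange_reward: float,
) -> torch.Tensor:
    """Sketch of the reward post-processing, not the merged code verbatim."""
    if scale_reward == "running":
        # scale by the standard deviation accumulated over all rollouts so far
        scores = scores / running_std
    elif scale_reward == "ref":
        # scale by a fixed reference deviation (from config or the first rollout)
        scores = scores / ref_std
    if cliprange_reward:
        scores = torch.clip(scores, -cliprange_reward, cliprange_reward)
    return scores
```

With scale_reward set to False, both scaling branches are skipped and only the clipping applies, matching the behaviour of the diff above.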
9 changes: 5 additions & 4 deletions trlx/utils/modeling.py
@@ -91,12 +91,13 @@ def update(self, xs: torch.Tensor) -> Tuple[float, float]:
         delta = xs_mean - self.mean
         tot_count = self.count + xs_count

-        m_a = self.var * self.count
-        m_b = xs_var * xs_count
-        m_2 = m_a + m_b + delta**2 * self.count * xs_count / tot_count
+        new_sum = xs_var * xs_count
+        # correct old_sum deviation accounting for the new mean
+        old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count
+        tot_sum = old_sum + new_sum

         self.mean += delta * xs_count / tot_count
-        self.var = m_2 / tot_count
+        self.var = tot_sum / tot_count
         self.std = (self.var * tot_count / (tot_count - 1)).sqrt()
         self.count = tot_count

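The hunk above is the core of RunningMoments.update: new_sum is the new batch's sum of squared deviations about its own mean, old_sum is the previous sum corrected for the shift to the combined mean, and their total divided by the combined count gives the updated variance (the standard formula for merging the variances of two disjoint samples). For reference, a minimal self-contained sketch consistent with that fragment is below; the initial count value and the return convention are assumptions here, not taken verbatim from trlx/utils/modeling.py:

```python
from typing import Tuple

import torch


class RunningMoments:
    """Minimal sketch of a running mean/std tracker (not the class as merged)."""

    def __init__(self):
        self.mean = 0.0
        self.var = 1.0
        self.std = 1.0
        self.count = 1e-24  # assumption: start with effectively zero observations

    def update(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Fold a batch into the running statistics; return the batch's own mean and std."""
        xs_count = xs.numel()
        xs_mean = xs.mean()
        xs_var = xs.var(unbiased=False)  # population variance of the batch

        delta = xs_mean - self.mean
        tot_count = self.count + xs_count

        new_sum = xs_var * xs_count
        # correct old_sum deviation accounting for the new mean
        old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count
        tot_sum = old_sum + new_sum

        self.mean += delta * xs_count / tot_count
        self.var = tot_sum / tot_count
        self.std = (self.var * tot_count / (tot_count - 1)).sqrt()
        self.count = tot_count

        # return the batch's mean and its unbiased standard deviation
        return xs_mean, (xs_var * xs_count / (xs_count - 1)).sqrt()
```

Under these assumptions the TestStatistics case added above should pass: after the four updates, self.mean and self.std agree, well within the 1e-6 tolerance, with the mean and unbiased standard deviation of the concatenated tensors.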