diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 109cc470e..b06880a2f 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -2,9 +2,9 @@ name: Build

 on:
   push:
-    branches: [ master ]
+    branches: [ main ]
   pull_request:
-    branches: [ master ]
+    branches: [ main ]

 jobs:
   build:
diff --git a/configs/ppo_config.yml b/configs/ppo_config.yml
index c526f564e..ead6ac39b 100644
--- a/configs/ppo_config.yml
+++ b/configs/ppo_config.yml
@@ -35,8 +35,10 @@ method:
   cliprange: 0.2 # clip range
   cliprange_value: 0.2 # clip range
   vf_coef: 2.3 # value term weight
-  scale_reward: True
-  clip_reward: 10
+  scale_reward: "running" # False | "ref" | "running" estimate against which to scale rewards
+  ref_mean: null
+  ref_std: null # rescale rewards with this deviation
+  cliprange_reward: 10
   gen_kwargs:
     max_length: 48 # LM max sample gen length
     min_length: 48 # LM min sample gen length
diff --git a/configs/test_config.yml b/configs/test_config.yml
index abf9e6df1..d26e24e9a 100644
--- a/configs/test_config.yml
+++ b/configs/test_config.yml
@@ -35,6 +35,10 @@ method:
   cliprange: 0.2 # clip range
   cliprange_value: 0.2 # clip range
   vf_coef: 1.0 # value term weight
+  scale_reward: "running" # False|"ref"|"running" estimate against which to scale rewards
+  cliprange_reward: 10
+  ref_mean: null
+  ref_std: null
   gen_kwargs:
     max_length: 48 # LM max sample gen length
     min_length: 48 # LM min sample gen length
diff --git a/tests/test_ppo.py b/tests/test_ppo.py
index a4f7e591c..a68ac3500 100644
--- a/tests/test_ppo.py
+++ b/tests/test_ppo.py
@@ -1,6 +1,7 @@
 import unittest
 from trlx.data.configs import TRLConfig
 from trlx.model.nn.ppo_models import GPTHydraHeadWithValueModel
+from trlx.utils.modeling import RunningMoments
 from transformers import AutoTokenizer

 import torch
@@ -44,3 +45,22 @@ def test_forward(self):
         logits_diff = torch.sum(unfrozen_logits - frozen_logits).item()
         self.assertEqual(hs_diff, 0)
         self.assertEqual(logits_diff, 0)
+
+class TestStatistics(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.m = RunningMoments()
+        cls.a1 = torch.arange(100, dtype=float)
+        cls.a2 = torch.ones(100, dtype=float)
+        cls.a3 = torch.exp(torch.arange(10, dtype=float))
+        cls.a4 = torch.tensor([-10, -1, 0, 1, 10], dtype=float)
+
+    def test_running_moments(self):
+        assert torch.isclose(self.m.update(self.a1)[1], self.a1.std(unbiased=True), atol=1e-6)
+        assert torch.isclose(self.m.update(self.a2)[1], self.a2.std(unbiased=True), atol=1e-6)
+        assert torch.isclose(self.m.update(self.a3)[1], self.a3.std(unbiased=True), atol=1e-6)
+        assert torch.isclose(self.m.update(self.a4)[1], self.a4.std(unbiased=True), atol=1e-6)
+
+        a = torch.hstack((self.a1, self.a2, self.a3, self.a4))
+        assert torch.isclose(self.m.mean, a.mean(), atol=1e-6)
+        assert torch.isclose(self.m.std, a.std(unbiased=True), atol=1e-6)
diff --git a/trlx/model/nn/ppo_models.py b/trlx/model/nn/ppo_models.py
index 2e4129ea6..29302ae0e 100644
--- a/trlx/model/nn/ppo_models.py
+++ b/trlx/model/nn/ppo_models.py
@@ -111,8 +111,10 @@ class PPOConfig(MethodConfig):
     cliprange: float
     cliprange_value: float
     vf_coef: float
-    scale_reward: bool
-    clip_reward: float
+    scale_reward: str
+    ref_mean: Optional[float]
+    ref_std: Optional[float]
+    cliprange_reward: float
     gen_kwargs: dict

     def get_advantages_and_returns(
diff --git a/trlx/orchestrator/ppo_orchestrator.py b/trlx/orchestrator/ppo_orchestrator.py
index 291e09e02..02d55ccbb 100644
--- a/trlx/orchestrator/ppo_orchestrator.py
+++ b/trlx/orchestrator/ppo_orchestrator.py
@@ -47,8 +47,8 @@ def __init__(
         self.rl_model.metric_fn = metric_fn

         self.running = RunningMoments()
-        self.ref_mean = None
-        self.ref_std = None
+        self.ref_mean = self.rl_model.config.method.ref_mean
+        self.ref_std = self.rl_model.config.method.ref_std

     def score(self, samples):
         """
@@ -84,19 +84,21 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
         scores = torch.as_tensor(self.score(texts), device=samples.device)
         stats["exp_score_time"] = time() - exp_score_time

+        # store statistics of the initial rollout as reference
         if self.ref_mean is None:
             self.ref_mean, self.ref_std = scores.mean(), scores.std()
         all_scores_mean, all_scores_std = self.running.update(scores)
-
         stats["exp_scores_mean"] = all_scores_mean
         stats["exp_scores_std"] = all_scores_std
         stats["running_mean"] = self.running.mean
         stats["running_std"] = self.running.std

-        if self.rl_model.config.method.scale_reward:
+        if self.rl_model.config.method.scale_reward == "running":
             scores /= self.running.std
+        elif self.rl_model.config.method.scale_reward == "ref":
+            scores /= self.ref_std

-        clip_reward = self.rl_model.config.method.clip_reward
+        clip_reward = self.rl_model.config.method.cliprange_reward
         if clip_reward:
             scores = torch.clip(scores, -clip_reward, clip_reward)

diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py
index 3fe30c330..97a368ebd 100644
--- a/trlx/utils/modeling.py
+++ b/trlx/utils/modeling.py
@@ -91,12 +91,13 @@ def update(self, xs: torch.Tensor) -> Tuple[float, float]:
         delta = xs_mean - self.mean
         tot_count = self.count + xs_count

-        m_a = self.var * self.count
-        m_b = xs_var * xs_count
-        m_2 = m_a + m_b + delta**2 * self.count * xs_count / tot_count
+        new_sum = xs_var * xs_count
+        # correct old_sum deviation accounting for the new mean
+        old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count
+        tot_sum = old_sum + new_sum

         self.mean += delta * xs_count / tot_count
-        self.var = m_2 / tot_count
+        self.var = tot_sum / tot_count
         self.std = (self.var * tot_count / (tot_count - 1)).sqrt()
         self.count = tot_count
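Note on the config change: the boolean scale_reward / clip_reward pair becomes scale_reward: False | "ref" | "running", optional ref_mean / ref_std, and cliprange_reward. The sketch below shows the behaviour of the new branch in make_experience with the orchestrator state reduced to plain arguments; the helper name scale_and_clip_rewards and the argument layout are illustrative only, not trlx API.

import torch


def scale_and_clip_rewards(
    scores: torch.Tensor,
    running_std: torch.Tensor,
    ref_std: torch.Tensor,
    scale_reward,            # False | "ref" | "running"
    cliprange_reward: float,
) -> torch.Tensor:
    """Mirror of the scaling/clipping branch added to make_experience."""
    if scale_reward == "running":
        scores = scores / running_std  # std maintained by RunningMoments
    elif scale_reward == "ref":
        scores = scores / ref_std      # std of the first rollout, or from config
    # scale_reward == False leaves the raw rewards untouched

    if cliprange_reward:
        scores = torch.clip(scores, -cliprange_reward, cliprange_reward)
    return scores


scores = torch.tensor([0.3, -4.0, 12.0])
print(scale_and_clip_rewards(scores, scores.std(), scores.std(), "running", 10))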
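Note on the modeling.py change: the m_a / m_b rewrite is the standard parallel-moments combination (Chan et al.), where delta**2 * count_a * count_b / tot_count shifts the old sum of squared deviations onto the combined mean before adding the new batch's sum. Below is a minimal, self-contained sketch of the same update for reference; the initial mean/var/std/count values are assumptions for this sketch and are not copied from trlx/utils/modeling.py, but the update body mirrors the diff.

import torch


class RunningMoments:
    """Stand-alone sketch of a running mean/std tracker using the parallel
    (Chan et al.) moment combination shown in the diff."""

    def __init__(self):
        # assumed initial values, chosen so the first update is dominated
        # entirely by the incoming batch
        self.mean = 0.0
        self.var = 1.0
        self.std = 1.0
        self.count = 1e-24

    def update(self, xs: torch.Tensor):
        """Fold a batch into the running moments; returns the batch's own
        mean and unbiased std, matching what the new test reads from update()[1]."""
        xs_count = xs.numel()
        xs_mean = xs.mean()
        xs_var = ((xs - xs_mean) ** 2).mean()  # biased (population) variance

        delta = xs_mean - self.mean
        tot_count = self.count + xs_count

        new_sum = xs_var * xs_count
        # shift the old sum of squared deviations to the combined mean
        old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count
        tot_sum = old_sum + new_sum

        self.mean += delta * xs_count / tot_count
        self.var = tot_sum / tot_count
        self.std = (self.var * tot_count / (tot_count - 1)).sqrt()
        self.count = tot_count

        return xs_mean, (xs_var * xs_count / (xs_count - 1)).sqrt()


if __name__ == "__main__":
    # streaming moments agree with the moments of the concatenated data
    m = RunningMoments()
    chunks = [torch.arange(100, dtype=float), torch.ones(100, dtype=float)]
    for chunk in chunks:
        m.update(chunk)
    full = torch.hstack(chunks)
    assert torch.isclose(m.mean, full.mean(), atol=1e-6)
    assert torch.isclose(m.std, full.std(unbiased=True), atol=1e-6)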