From dcabc515d8f1d1139bc0f7efe9636c2c56f6fcd2 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 17 Nov 2022 02:04:48 +0200 Subject: [PATCH 1/5] fix(ppo): optional reward scaling and minibatch advantage whitening --- configs/ppo_config.yml | 2 +- trlx/orchestrator/ppo_orchestrator.py | 27 ++++++++-- trlx/utils/modeling.py | 75 +++++++++++++++++++++++++-- 3 files changed, 95 insertions(+), 9 deletions(-) diff --git a/configs/ppo_config.yml b/configs/ppo_config.yml index 38a36cddc..d6ee43c70 100644 --- a/configs/ppo_config.yml +++ b/configs/ppo_config.yml @@ -6,7 +6,7 @@ model: train: seq_length: 48 # Size of LM context - epochs: 1000 # Train for max(epochs, total_steps) + epochs: 100 # Train for max(epochs, total_steps) total_steps: 10000 # Train for max(epochs, total_steps) batch_size: 128 # batch size diff --git a/trlx/orchestrator/ppo_orchestrator.py b/trlx/orchestrator/ppo_orchestrator.py index 7f6096aae..b097d95d3 100644 --- a/trlx/orchestrator/ppo_orchestrator.py +++ b/trlx/orchestrator/ppo_orchestrator.py @@ -8,8 +8,9 @@ from trlx.orchestrator import Orchestrator, register_orchestrator from trlx.pipeline import BasePipeline from trlx.utils import Clock -from trlx.utils.modeling import logprobs_from_logits +from trlx.utils.modeling import logprobs_from_logits, RunningMoments +from time import time import ray from ray.air import session @@ -45,6 +46,10 @@ def __init__( self.rl_model.reward_fn = reward_fn self.rl_model.metric_fn = metric_fn + self.running = RunningMoments() + self.ref_mean = None + self.ref_std = None + def score(self, samples): """ Batched scoring function taking text and generating scalar @@ -66,15 +71,28 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): self.pipeline_iterator = iter(self.pipeline_loader) batch = next(self.pipeline_iterator) + exp_generate_time = time() samples = self.rl_model.generate(**batch) + stats["exp_generate_time"] = time() - exp_generate_time query_tensors = batch.input_ids response_tensors = samples[:, query_tensors.shape[1] :] texts = self.rl_model.tokenizer.batch_decode( samples, skip_special_tokens=True ) - scores = torch.as_tensor(self.score(texts)) - + exp_score_time = time() + scores = torch.as_tensor(self.score(texts), device=samples.device) + stats["exp_score_time"] = time() - exp_score_time + + if self.ref_mean is None: + self.ref_mean, self.ref_std = scores.mean(), scores.std() + all_scores_mean, all_scores_std = self.running.update(scores) + scores /= self.running.std + + stats["exp_scores_mean"] = all_scores_mean + stats["exp_scores_std"] = all_scores_std + stats["running_mean"] = self.running.mean + stats["running_std"] = self.running.std # Precompute logprobs, values all_tokens = torch.cat( (query_tensors.to(samples.device), response_tensors), dim=1 @@ -126,7 +144,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): ] ppo_rl_elements += new_ppo_rl_elements - stats = {"exp_time": exp_time} + stats["kl_ctl_value"] = self.rl_model.kl_ctl.value + stats["exp_time"] = exp_time if not ray.is_initialized(): self.rl_model.accelerator.log(stats, step=iter_count) diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py index 2751edd67..3fe30c330 100644 --- a/trlx/utils/modeling.py +++ b/trlx/utils/modeling.py @@ -2,12 +2,33 @@ import torch import torch.nn.functional as F +import torch.distributed as dist +from typing import Tuple -def whiten(values, shift_mean=True): - """Whiten values.""" - mean, var = torch.mean(values), 
torch.var(values) - whitened = (values - mean) * torch.rsqrt(var + 1e-8) +def get_global_statistics(xs: torch.Tensor) -> Tuple[float, float, int]: + """ + Computes element-wise mean and variance of the tensor across processes + """ + sum_and_count = torch.tensor([xs.sum(), xs.numel()], device=xs.device) + dist.all_reduce(sum_and_count, dist.ReduceOp.SUM) + global_sum, count = sum_and_count + global_mean = global_sum / count + + sum_var = torch.sum((xs - global_mean) ** 2) + dist.all_reduce(sum_var, dist.ReduceOp.SUM) + global_var = sum_var / count + return global_mean, global_var, count + + +def whiten(xs: torch.Tensor, shift_mean=True, distributed=True) -> torch.Tensor: + """Whitens values""" + if distributed and dist.is_initialized(): + mean, var, _ = get_global_statistics(xs) + else: + var, mean = torch.var_mean(xs) + + whitened = (xs - mean) * torch.rsqrt(var + 1e-8) if not shift_mean: whitened += mean return whitened @@ -34,3 +55,49 @@ def flatten_dict( else: items.append((new_key, v)) return dict(items) + + +def log_stat(stats: dict, name: str, xs: torch.Tensor, mask: torch.Tensor, n: int): + mean = (xs * mask).sum() / n + stats.update( + { + f"{name}/mean": mean, + f"{name}/min": torch.where(mask.bool(), xs, np.inf).min(), + f"{name}/max": torch.where(mask.bool(), xs, -np.inf).max(), + f"{name}/std": torch.sqrt(((xs - mean) * mask).pow(2).sum() / n), + } + ) + + +class RunningMoments: + def __init__(self): + """ + Calculates the running mean and standard deviation of a data stream. Modified version of + https://github.com/DLR-RM/stable-baselines3/blob/a6f5049a99a4c21a6f0bcce458ca3306cef310e0/stable_baselines3/common/running_mean_std.py + """ + self.mean = 0 + self.std = 1 + self.var = 1 + self.count = 1e-24 + + def update(self, xs: torch.Tensor) -> Tuple[float, float]: + """Updates running moments from batch's moments computed across ranks""" + if dist.is_initialized(): + xs_mean, xs_var, xs_count = get_global_statistics(xs) + else: + xs_count = xs.numel() + xs_var, xs_mean = torch.var_mean(xs, unbiased=False) + + delta = xs_mean - self.mean + tot_count = self.count + xs_count + + m_a = self.var * self.count + m_b = xs_var * xs_count + m_2 = m_a + m_b + delta**2 * self.count * xs_count / tot_count + + self.mean += delta * xs_count / tot_count + self.var = m_2 / tot_count + self.std = (self.var * tot_count / (tot_count - 1)).sqrt() + self.count = tot_count + + return xs_mean, (xs_var * xs_count / (xs_count - 1)).sqrt() From 267c9d44980e3f151a035636fefc18cdb02848d3 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 17 Nov 2022 14:43:15 +0200 Subject: [PATCH 2/5] feat(ppo): add optional reward clipping --- configs/ppo_config.yml | 2 ++ trlx/model/nn/ppo_models.py | 2 ++ trlx/orchestrator/ppo_orchestrator.py | 9 ++++++++- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/configs/ppo_config.yml b/configs/ppo_config.yml index d6ee43c70..c526f564e 100644 --- a/configs/ppo_config.yml +++ b/configs/ppo_config.yml @@ -35,6 +35,8 @@ method: cliprange: 0.2 # clip range cliprange_value: 0.2 # clip range vf_coef: 2.3 # value term weight + scale_reward: True + clip_reward: 10 gen_kwargs: max_length: 48 # LM max sample gen length min_length: 48 # LM min sample gen length diff --git a/trlx/model/nn/ppo_models.py b/trlx/model/nn/ppo_models.py index 7c6158e3e..2e4129ea6 100644 --- a/trlx/model/nn/ppo_models.py +++ b/trlx/model/nn/ppo_models.py @@ -111,6 +111,8 @@ class PPOConfig(MethodConfig): cliprange: float cliprange_value: float 
vf_coef: float + scale_reward: bool + clip_reward: float gen_kwargs: dict def get_advantages_and_returns( diff --git a/trlx/orchestrator/ppo_orchestrator.py b/trlx/orchestrator/ppo_orchestrator.py index b097d95d3..291e09e02 100644 --- a/trlx/orchestrator/ppo_orchestrator.py +++ b/trlx/orchestrator/ppo_orchestrator.py @@ -87,12 +87,19 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): if self.ref_mean is None: self.ref_mean, self.ref_std = scores.mean(), scores.std() all_scores_mean, all_scores_std = self.running.update(scores) - scores /= self.running.std stats["exp_scores_mean"] = all_scores_mean stats["exp_scores_std"] = all_scores_std stats["running_mean"] = self.running.mean stats["running_std"] = self.running.std + + if self.rl_model.config.method.scale_reward: + scores /= self.running.std + + clip_reward = self.rl_model.config.method.clip_reward + if clip_reward: + scores = torch.clip(scores, -clip_reward, clip_reward) + # Precompute logprobs, values all_tokens = torch.cat( (query_tensors.to(samples.device), response_tensors), dim=1 From 8ae13ac431432e2c46665fa7633baee99af3f3d7 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 17 Nov 2022 22:29:44 +0200 Subject: [PATCH 3/5] chore(ppo): add tests, comments --- configs/ppo_config.yml | 2 +- configs/test_config.yml | 2 ++ tests/test_ppo.py | 20 ++++++++++++++++++++ trlx/model/nn/ppo_models.py | 2 +- trlx/orchestrator/ppo_orchestrator.py | 5 ++++- trlx/utils/modeling.py | 9 +++++---- 6 files changed, 33 insertions(+), 7 deletions(-) diff --git a/configs/ppo_config.yml b/configs/ppo_config.yml index c526f564e..88b16359c 100644 --- a/configs/ppo_config.yml +++ b/configs/ppo_config.yml @@ -35,7 +35,7 @@ method: cliprange: 0.2 # clip range cliprange_value: 0.2 # clip range vf_coef: 2.3 # value term weight - scale_reward: True + scale_reward: "running" # False|"ref"|"running" estimate against which to scale rewards clip_reward: 10 gen_kwargs: max_length: 48 # LM max sample gen length diff --git a/configs/test_config.yml b/configs/test_config.yml index abf9e6df1..24a1927e0 100644 --- a/configs/test_config.yml +++ b/configs/test_config.yml @@ -35,6 +35,8 @@ method: cliprange: 0.2 # clip range cliprange_value: 0.2 # clip range vf_coef: 1.0 # value term weight + scale_reward: "running" # False|"ref"|"running" estimate against which to scale rewards + clip_reward: 10 gen_kwargs: max_length: 48 # LM max sample gen length min_length: 48 # LM min sample gen length diff --git a/tests/test_ppo.py b/tests/test_ppo.py index a4f7e591c..a68ac3500 100644 --- a/tests/test_ppo.py +++ b/tests/test_ppo.py @@ -1,6 +1,7 @@ import unittest from trlx.data.configs import TRLConfig from trlx.model.nn.ppo_models import GPTHydraHeadWithValueModel +from trlx.utils.modeling import RunningMoments from transformers import AutoTokenizer import torch @@ -44,3 +45,22 @@ def test_forward(self): logits_diff = torch.sum(unfrozen_logits - frozen_logits).item() self.assertEqual(hs_diff, 0) self.assertEqual(logits_diff, 0) + +class TestStatistics(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.m = RunningMoments() + cls.a1 = torch.arange(100, dtype=float) + cls.a2 = torch.ones(100, dtype=float) + cls.a3 = torch.exp(torch.arange(10, dtype=float)) + cls.a4 = torch.tensor([-10, -1, 0, 1, 10], dtype=float) + + def test_running_moments(self): + assert torch.isclose(self.m.update(self.a1)[1], self.a1.std(unbiased=True), atol=1e-6) + assert torch.isclose(self.m.update(self.a2)[1], 
self.a2.std(unbiased=True), atol=1e-6) + assert torch.isclose(self.m.update(self.a3)[1], self.a3.std(unbiased=True), atol=1e-6) + assert torch.isclose(self.m.update(self.a4)[1], self.a4.std(unbiased=True), atol=1e-6) + + a = torch.hstack((self.a1, self.a2, self.a3, self.a4)) + assert torch.isclose(self.m.mean, a.mean(), atol=1e-6) + assert torch.isclose(self.m.std, a.std(unbiased=True), atol=1e-6) diff --git a/trlx/model/nn/ppo_models.py b/trlx/model/nn/ppo_models.py index 2e4129ea6..0ad8e2be7 100644 --- a/trlx/model/nn/ppo_models.py +++ b/trlx/model/nn/ppo_models.py @@ -111,7 +111,7 @@ class PPOConfig(MethodConfig): cliprange: float cliprange_value: float vf_coef: float - scale_reward: bool + scale_reward: str clip_reward: float gen_kwargs: dict diff --git a/trlx/orchestrator/ppo_orchestrator.py b/trlx/orchestrator/ppo_orchestrator.py index 291e09e02..382c9d0c6 100644 --- a/trlx/orchestrator/ppo_orchestrator.py +++ b/trlx/orchestrator/ppo_orchestrator.py @@ -84,6 +84,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): scores = torch.as_tensor(self.score(texts), device=samples.device) stats["exp_score_time"] = time() - exp_score_time + # store statistics of the initial rollout as reference if self.ref_mean is None: self.ref_mean, self.ref_std = scores.mean(), scores.std() all_scores_mean, all_scores_std = self.running.update(scores) @@ -93,8 +94,10 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): stats["running_mean"] = self.running.mean stats["running_std"] = self.running.std - if self.rl_model.config.method.scale_reward: + if self.rl_model.config.method.scale_reward == "running": scores /= self.running.std + elif self.rl_model.config.method.scale_reward == "ref": + scores /= self.ref_std clip_reward = self.rl_model.config.method.clip_reward if clip_reward: diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py index 3fe30c330..97a368ebd 100644 --- a/trlx/utils/modeling.py +++ b/trlx/utils/modeling.py @@ -91,12 +91,13 @@ def update(self, xs: torch.Tensor) -> Tuple[float, float]: delta = xs_mean - self.mean tot_count = self.count + xs_count - m_a = self.var * self.count - m_b = xs_var * xs_count - m_2 = m_a + m_b + delta**2 * self.count * xs_count / tot_count + new_sum = xs_var * xs_count + # correct old_sum deviation accounting for the new mean + old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count + tot_sum = old_sum + new_sum self.mean += delta * xs_count / tot_count - self.var = m_2 / tot_count + self.var = tot_sum / tot_count self.std = (self.var * tot_count / (tot_count - 1)).sqrt() self.count = tot_count From d4998a6b34860b7c95614d2a072dbe904ae7aa60 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Fri, 18 Nov 2022 16:00:28 +0200 Subject: [PATCH 4/5] fix(github): rename master to main for build --- .github/workflows/build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 109cc470e..b06880a2f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -2,9 +2,9 @@ name: Build on: push: - branches: [ master ] + branches: [ main ] pull_request: - branches: [ master ] + branches: [ main ] jobs: build: From b400aeb64c69f2f4e02814196e30ccbfe0145b40 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Fri, 18 Nov 2022 23:56:14 +0200 Subject: [PATCH 5/5] feat(ppo): add manual reward scaling --- 
configs/ppo_config.yml | 6 ++++-- configs/test_config.yml | 4 +++- trlx/model/nn/ppo_models.py | 4 +++- trlx/orchestrator/ppo_orchestrator.py | 7 +++---- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/configs/ppo_config.yml b/configs/ppo_config.yml index 88b16359c..ead6ac39b 100644 --- a/configs/ppo_config.yml +++ b/configs/ppo_config.yml @@ -35,8 +35,10 @@ method: cliprange: 0.2 # clip range cliprange_value: 0.2 # clip range vf_coef: 2.3 # value term weight - scale_reward: "running" # False|"ref"|"running" estimate against which to scale rewards - clip_reward: 10 + scale_reward: "running" # False | "ref" | "running" estimate against which to scale rewards + ref_mean: null + ref_std: null # rescale rewards with this deviation + cliprange_reward: 10 gen_kwargs: max_length: 48 # LM max sample gen length min_length: 48 # LM min sample gen length diff --git a/configs/test_config.yml b/configs/test_config.yml index 24a1927e0..d26e24e9a 100644 --- a/configs/test_config.yml +++ b/configs/test_config.yml @@ -36,7 +36,9 @@ method: cliprange_value: 0.2 # clip range vf_coef: 1.0 # value term weight scale_reward: "running" # False|"ref"|"running" estimate against which to scale rewards - clip_reward: 10 + cliprange_reward: 10 + ref_mean: null + ref_std: null gen_kwargs: max_length: 48 # LM max sample gen length min_length: 48 # LM min sample gen length diff --git a/trlx/model/nn/ppo_models.py b/trlx/model/nn/ppo_models.py index 0ad8e2be7..29302ae0e 100644 --- a/trlx/model/nn/ppo_models.py +++ b/trlx/model/nn/ppo_models.py @@ -112,7 +112,9 @@ class PPOConfig(MethodConfig): cliprange_value: float vf_coef: float scale_reward: str - clip_reward: float + ref_mean: Optional[float] + ref_std: Optional[float] + cliprange_reward: float gen_kwargs: dict def get_advantages_and_returns( diff --git a/trlx/orchestrator/ppo_orchestrator.py b/trlx/orchestrator/ppo_orchestrator.py index 382c9d0c6..02d55ccbb 100644 --- a/trlx/orchestrator/ppo_orchestrator.py +++ b/trlx/orchestrator/ppo_orchestrator.py @@ -47,8 +47,8 @@ def __init__( self.rl_model.metric_fn = metric_fn self.running = RunningMoments() - self.ref_mean = None - self.ref_std = None + self.ref_mean = self.rl_model.config.method.ref_mean + self.ref_std = self.rl_model.config.method.ref_std def score(self, samples): """ @@ -88,7 +88,6 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): if self.ref_mean is None: self.ref_mean, self.ref_std = scores.mean(), scores.std() all_scores_mean, all_scores_std = self.running.update(scores) - stats["exp_scores_mean"] = all_scores_mean stats["exp_scores_std"] = all_scores_std stats["running_mean"] = self.running.mean @@ -99,7 +98,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): elif self.rl_model.config.method.scale_reward == "ref": scores /= self.ref_std - clip_reward = self.rl_model.config.method.clip_reward + clip_reward = self.rl_model.config.method.cliprange_reward if clip_reward: scores = torch.clip(scores, -clip_reward, clip_reward)
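
The advantage whitening and running-moment updates introduced in PATCH 1/5 both lean on a helper that pools per-rank statistics with two all_reduce calls: the first reduction sums values and element counts (giving the global mean), the second sums squared deviations from that global mean (giving the global, biased variance). Below is a minimal sketch of that helper and the whiten wrapper, mirroring the code added to trlx/utils/modeling.py in this series and falling back to torch.var_mean when no process group is initialized; the tensor-typed return annotations are this sketch's choice, not the patch's.

    import torch
    import torch.distributed as dist
    from typing import Tuple


    def get_global_statistics(xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Mean and biased variance of `xs` pooled across all data-parallel ranks."""
        # First reduction: global sum and global element count -> global mean
        sum_and_count = torch.tensor([xs.sum(), xs.numel()], device=xs.device)
        dist.all_reduce(sum_and_count, dist.ReduceOp.SUM)
        global_sum, count = sum_and_count
        global_mean = global_sum / count

        # Second reduction: squared deviations from the *global* mean -> global variance
        sum_var = torch.sum((xs - global_mean) ** 2)
        dist.all_reduce(sum_var, dist.ReduceOp.SUM)
        global_var = sum_var / count
        return global_mean, global_var, count


    def whiten(xs: torch.Tensor, shift_mean: bool = True, distributed: bool = True) -> torch.Tensor:
        """Normalize `xs` to zero mean and unit variance (optionally restoring the mean)."""
        if distributed and dist.is_initialized():
            mean, var, _ = get_global_statistics(xs)
        else:
            var, mean = torch.var_mean(xs)

        whitened = (xs - mean) * torch.rsqrt(var + 1e-8)
        if not shift_mean:
            whitened += mean
        return whitened

Because the distributed path uses the same pooled mean and variance on every rank, each worker whitens its minibatch advantages with identical statistics instead of its local ones.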
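
Reward scaling needs a standard-deviation estimate that stays stable across rollouts, which is what the RunningMoments class provides: each batch's mean and biased variance are folded into the running statistics with the parallel-variance (Chan et al.) update that PATCH 3/5 rewrites in terms of old and new sums of squared deviations. The following is a single-process sketch of that same update, with the distributed branch dropped for brevity.

    import torch
    from typing import Tuple


    class RunningMoments:
        """Running mean/std of a reward stream (single-process variant of the class
        added to trlx/utils/modeling.py)."""

        def __init__(self):
            self.mean = 0.0
            self.std = 1.0
            self.var = 1.0
            self.count = 1e-24  # effectively zero, but avoids division by zero

        def update(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            xs_count = xs.numel()
            xs_var, xs_mean = torch.var_mean(xs, unbiased=False)

            delta = xs_mean - self.mean
            tot_count = self.count + xs_count

            # Sums of squared deviations: the old sum is corrected for the shift
            # of the overall mean before the two sums are combined
            new_sum = xs_var * xs_count
            old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count
            tot_sum = old_sum + new_sum

            self.mean += delta * xs_count / tot_count
            self.var = tot_sum / tot_count
            self.std = (self.var * tot_count / (tot_count - 1)).sqrt()  # unbiased std
            self.count = tot_count

            # The batch's own mean and unbiased std, logged as exp_scores_mean/std
            return xs_mean, (xs_var * xs_count / (xs_count - 1)).sqrt()

After feeding it the four tensors from tests/test_ppo.py, the running mean and std agree (to 1e-6) with the mean and unbiased std of their concatenation, which is exactly what TestStatistics.test_running_moments asserts.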
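
Taken together, the configuration knobs introduced across these patches reduce the reward post-processing in PPOOrchestrator.make_experience to: optionally divide the scores by a standard-deviation estimate (the running one, or a fixed reference taken from ref_std or the first rollout), then optionally clip them to ±cliprange_reward. The function below is not part of trlx; it is an illustrative, standalone condensation of that logic using the same field names as the config.

    import torch


    def postprocess_scores(scores, running_std, ref_std,
                           scale_reward="running", cliprange_reward=10.0):
        """Sketch of the scaling/clipping steps wired into make_experience.
        scale_reward is False | "ref" | "running", matching the config field."""
        if scale_reward == "running":
            scores = scores / running_std      # running estimate, updated every rollout
        elif scale_reward == "ref":
            scores = scores / ref_std          # fixed reference: config ref_std, or the
                                               # std measured on the initial rollout

        if cliprange_reward:                   # a falsy value disables clipping
            scores = torch.clip(scores, -cliprange_reward, cliprange_reward)
        return scores

With scale_reward: "running" and cliprange_reward: 10 (the defaults in configs/ppo_config.yml after PATCH 5/5), rewards are divided by the running std and then clipped to [-10, 10] before advantages are computed.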