Add optional reward scaling (#95)
* fix(ppo): optional reward scaling and minibatch advantage whitening

* feat(ppo): add optional reward clipping
maxreciprocate authored Nov 17, 2022
1 parent 90ce1aa commit 12598e9
Showing 4 changed files with 105 additions and 8 deletions.
4 changes: 3 additions & 1 deletion configs/ppo_config.yml
@@ -6,7 +6,7 @@ model:

train:
  seq_length: 48  # Size of LM context
-  epochs: 1000  # Train for max(epochs, total_steps)
+  epochs: 100  # Train for max(epochs, total_steps)
  total_steps: 10000  # Train for max(epochs, total_steps)
  batch_size: 128  # batch size

@@ -35,6 +35,8 @@ method:
  cliprange: 0.2  # clip range
  cliprange_value: 0.2  # clip range
  vf_coef: 2.3  # value term weight
+  scale_reward: True
+  clip_reward: 10
  gen_kwargs:
    max_length: 48  # LM max sample gen length
    min_length: 48  # LM min sample gen length
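In effect, `scale_reward` turns on normalization of rewards by a running standard deviation, and `clip_reward` bounds the (scaled) rewards. A minimal sketch of the combined behavior, assuming scaling is applied before clipping as in the orchestrator change below (the standalone helper `condition_rewards` is illustrative, not part of trlx):

```python
import torch

def condition_rewards(scores: torch.Tensor, running_std: float,
                      scale_reward: bool, clip_reward: float) -> torch.Tensor:
    """Illustrative helper: scale rewards by a running std, then clip them."""
    if scale_reward:
        # Keep reward magnitudes comparable as their distribution drifts
        scores = scores / running_std
    if clip_reward:
        # Bound every reward to [-clip_reward, clip_reward]
        scores = torch.clip(scores, -clip_reward, clip_reward)
    return scores

rewards = torch.tensor([0.5, -3.0, 12.0, 40.0])
print(condition_rewards(rewards, running_std=4.0, scale_reward=True, clip_reward=10.0))
# tensor([ 0.1250, -0.7500,  3.0000, 10.0000])
```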
2 changes: 2 additions & 0 deletions trlx/model/nn/ppo_models.py
@@ -111,6 +111,8 @@ class PPOConfig(MethodConfig):
    cliprange: float
    cliprange_value: float
    vf_coef: float
+    scale_reward: bool
+    clip_reward: float
    gen_kwargs: dict

    def get_advantages_and_returns(
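Note that `clip_reward` doubles as an on/off switch: the orchestrator guards clipping with a truthiness check, so `0` disables it. A simplified stand-in for how the YAML keys above land on these fields (the class name is illustrative; trlx's real loading goes through its config machinery, which is elided here):

```python
from dataclasses import dataclass

@dataclass
class PPOConfigSketch:
    """Illustrative subset of PPOConfig; only the fields discussed here."""
    cliprange: float = 0.2
    cliprange_value: float = 0.2
    vf_coef: float = 2.3
    scale_reward: bool = True  # divide rewards by the running std
    clip_reward: float = 10    # clip to [-10, 10]; 0 disables clipping

cfg = PPOConfigSketch()
assert cfg.scale_reward and cfg.clip_reward == 10
```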
32 changes: 29 additions & 3 deletions trlx/orchestrator/ppo_orchestrator.py
Expand Up @@ -8,8 +8,9 @@
from trlx.orchestrator import Orchestrator, register_orchestrator
from trlx.pipeline import BasePipeline
from trlx.utils import Clock
-from trlx.utils.modeling import logprobs_from_logits
+from trlx.utils.modeling import logprobs_from_logits, RunningMoments

+from time import time
import ray
from ray.air import session

@@ -45,6 +46,10 @@ def __init__(
        self.rl_model.reward_fn = reward_fn
        self.rl_model.metric_fn = metric_fn

+        self.running = RunningMoments()
+        self.ref_mean = None
+        self.ref_std = None

    def score(self, samples):
        """
        Batched scoring function taking text and generating scalar
@@ -66,14 +71,34 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
                self.pipeline_iterator = iter(self.pipeline_loader)
                batch = next(self.pipeline_iterator)

+            exp_generate_time = time()
            samples = self.rl_model.generate(**batch)
+            stats["exp_generate_time"] = time() - exp_generate_time

            query_tensors = batch.input_ids
            response_tensors = samples[:, query_tensors.shape[1] :]
            texts = self.rl_model.tokenizer.batch_decode(
                samples, skip_special_tokens=True
            )
-            scores = torch.as_tensor(self.score(texts))
+            exp_score_time = time()
+            scores = torch.as_tensor(self.score(texts), device=samples.device)
+            stats["exp_score_time"] = time() - exp_score_time
+
+            if self.ref_mean is None:
+                self.ref_mean, self.ref_std = scores.mean(), scores.std()
+            all_scores_mean, all_scores_std = self.running.update(scores)
+
+            stats["exp_scores_mean"] = all_scores_mean
+            stats["exp_scores_std"] = all_scores_std
+            stats["running_mean"] = self.running.mean
+            stats["running_std"] = self.running.std
+
+            if self.rl_model.config.method.scale_reward:
+                scores /= self.running.std
+
+            clip_reward = self.rl_model.config.method.clip_reward
+            if clip_reward:
+                scores = torch.clip(scores, -clip_reward, clip_reward)

            # Precompute logprobs, values
            all_tokens = torch.cat(
@@ -126,7 +151,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
            ]
            ppo_rl_elements += new_ppo_rl_elements

-        stats = {"exp_time": exp_time}
+        stats["kl_ctl_value"] = self.rl_model.kl_ctl.value
+        stats["exp_time"] = exp_time

        if not ray.is_initialized():
            self.rl_model.accelerator.log(stats, step=iter_count)
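The first batch's mean and std are cached as `ref_mean`/`ref_std`, the running moments are updated on every iteration, and only then are the scores scaled and clipped. A stripped-down stand-in for this part of `make_experience` (the reward batches are synthetic; the real loop also handles generation, tokenization, and rollout storage):

```python
import torch
from trlx.utils.modeling import RunningMoments  # added by this commit

scale_reward, clip_reward = True, 10  # mirrors ppo_config.yml above
running = RunningMoments()

for scores in [torch.randn(8) * 3, torch.randn(8) * 3 + 1]:  # fake per-batch rewards
    batch_mean, batch_std = running.update(scores)  # this batch's own moments
    if scale_reward:
        scores = scores / running.std  # running std over every batch seen so far
    if clip_reward:
        scores = torch.clip(scores, -clip_reward, clip_reward)
```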
75 changes: 71 additions & 4 deletions trlx/utils/modeling.py
@@ -2,12 +2,33 @@

import torch
import torch.nn.functional as F
+import torch.distributed as dist
+from typing import Tuple


-def whiten(values, shift_mean=True):
-    """Whiten values."""
-    mean, var = torch.mean(values), torch.var(values)
-    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
+def get_global_statistics(xs: torch.Tensor) -> Tuple[float, float, int]:
+    """
+    Computes element-wise mean and variance of the tensor across processes
+    """
+    sum_and_count = torch.tensor([xs.sum(), xs.numel()], device=xs.device)
+    dist.all_reduce(sum_and_count, dist.ReduceOp.SUM)
+    global_sum, count = sum_and_count
+    global_mean = global_sum / count
+
+    sum_var = torch.sum((xs - global_mean) ** 2)
+    dist.all_reduce(sum_var, dist.ReduceOp.SUM)
+    global_var = sum_var / count
+    return global_mean, global_var, count
+
+
+def whiten(xs: torch.Tensor, shift_mean=True, distributed=True) -> torch.Tensor:
+    """Whitens values"""
+    if distributed and dist.is_initialized():
+        mean, var, _ = get_global_statistics(xs)
+    else:
+        var, mean = torch.var_mean(xs)
+
+    whitened = (xs - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened += mean
    return whitened
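On a single process `whiten` falls back to local statistics; under `torch.distributed`, every rank receives the same global mean and variance from the `all_reduce` calls, so the minibatch advantage whitening mentioned in the commit message stays consistent across data-parallel workers. A quick single-process check (the input values are arbitrary):

```python
import torch
from trlx.utils.modeling import whiten

advantages = torch.tensor([1.0, 2.0, 3.0, 4.0])
w = whiten(advantages)  # dist.is_initialized() is False here, so local var_mean is used
print(w.mean().item(), w.std().item())  # approximately 0.0 and 1.0
```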
@@ -34,3 +55,49 @@ def flatten_dict(
        else:
            items.append((new_key, v))
    return dict(items)


+def log_stat(stats: dict, name: str, xs: torch.Tensor, mask: torch.Tensor, n: int):
+    mean = (xs * mask).sum() / n
+    stats.update(
+        {
+            f"{name}/mean": mean,
+            f"{name}/min": torch.where(mask.bool(), xs, np.inf).min(),
+            f"{name}/max": torch.where(mask.bool(), xs, -np.inf).max(),
+            f"{name}/std": torch.sqrt(((xs - mean) * mask).pow(2).sum() / n),
+        }
+    )


+class RunningMoments:
+    def __init__(self):
+        """
+        Calculates the running mean and standard deviation of a data stream. Modified version of
+        https://github.com/DLR-RM/stable-baselines3/blob/a6f5049a99a4c21a6f0bcce458ca3306cef310e0/stable_baselines3/common/running_mean_std.py
+        """
+        self.mean = 0
+        self.std = 1
+        self.var = 1
+        self.count = 1e-24
+
+    def update(self, xs: torch.Tensor) -> Tuple[float, float]:
+        """Updates running moments from batch's moments computed across ranks"""
+        if dist.is_initialized():
+            xs_mean, xs_var, xs_count = get_global_statistics(xs)
+        else:
+            xs_count = xs.numel()
+            xs_var, xs_mean = torch.var_mean(xs, unbiased=False)
+
+        delta = xs_mean - self.mean
+        tot_count = self.count + xs_count
+
+        m_a = self.var * self.count
+        m_b = xs_var * xs_count
+        m_2 = m_a + m_b + delta**2 * self.count * xs_count / tot_count
+
+        self.mean += delta * xs_count / tot_count
+        self.var = m_2 / tot_count
+        self.std = (self.var * tot_count / (tot_count - 1)).sqrt()
+        self.count = tot_count
+
+        return xs_mean, (xs_var * xs_count / (xs_count - 1)).sqrt()
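`update` merges the batch's moments into the running ones with the standard parallel-variance combination (Chan et al.): m2 = m_A + m_B + delta^2 * n_A * n_B / (n_A + n_B), and returns the batch's own Bessel-corrected moments. A quick sanity check that streaming batches through it reproduces the moments of the concatenated data (the 1e-24 initial count contributes only negligible bias):

```python
import torch
from trlx.utils.modeling import RunningMoments

torch.manual_seed(0)
batches = [torch.randn(64) for _ in range(10)]

rm = RunningMoments()
for b in batches:
    rm.update(b)

all_xs = torch.cat(batches)
assert torch.isclose(torch.as_tensor(rm.mean), all_xs.mean(), atol=1e-4)
assert torch.isclose(torch.as_tensor(rm.std), all_xs.std(), atol=1e-4)  # both Bessel-corrected
```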
