diff --git a/examples/scripts/04-ppo-sentiment.py b/examples/scripts/04-ppo-sentiment.py deleted file mode 100644 index a8e1d909541..00000000000 --- a/examples/scripts/04-ppo-sentiment.py +++ /dev/null @@ -1,139 +0,0 @@ -import torch -import wandb -import time -import os -from tqdm import tqdm -import numpy as np -import pandas as pd -tqdm.pandas() - -from datasets import load_dataset - -from transformers import AutoTokenizer, pipeline - -from trl import AutoModelForCausalLMWithValueHead -from trl.ppo import PPOTrainer -from trl.core import build_bert_batch_from_txt, listify_batch - -config = { - "model_name": "lvwerra/gpt2-imdb", - "cls_model_name": "lvwerra/distilbert-imdb", - "steps": 20000, - "batch_size": 256, - "forward_batch_size": 16, - "ppo_epochs": 4, - "txt_in_min_len": 2, - "txt_in_max_len": 8, - "txt_out_min_len": 4, - "txt_out_max_len": 16, - "lr": 1.41e-5, - "init_kl_coef":0.2, - "target": 6, - "horizon":10000, - "gamma":1, - "lam":0.95, - "cliprange": .2, - "cliprange_value":.2, - "vf_coef":.1, -} - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -pipe_device = 0 if torch.cuda.is_available() else -1 - -wandb.init(name='run-42', project='gpt2-test', config=config) - -# load imdb with datasets -ds = load_dataset('imdb', split='train') -ds = ds.rename_columns({'text': 'review', 'label': 'sentiment'}) -ds = ds.filter(lambda x: len(x["review"])>200, batched=False) - -sent_kwargs = { - "return_all_scores": True, - "function_to_apply": "none", - "batch_size": config["forward_batch_size"] -} - -sentiment_pipe = pipeline("sentiment-analysis","lvwerra/distilbert-imdb", device=pipe_device) - -gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(config['model_name']) -gpt2_model_ref = AutoModelForCausalLMWithValueHead.from_pretrained(config['model_name']) - -gpt2_tokenizer = AutoTokenizer.from_pretrained(config['model_name']) -gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token - -wandb.watch(gpt2_model, log='all') - -gpt2_model.to(device) -gpt2_model_ref.to(device) - -class LengthSampler: - def __init__(self, min_value, max_value): - self.values = list(range(min_value, max_value)) - def __call__(self): - return np.random.choice(self.values) - -input_size = LengthSampler(config["txt_in_min_len"], config["txt_in_max_len"]) -output_size = LengthSampler(config["txt_out_min_len"], config["txt_out_max_len"]) - -def tokenize(sample): - sample["tokens"] = gpt2_tokenizer.encode(sample["review"])[:input_size()] - sample["query"] = gpt2_tokenizer.decode(sample["tokens"]) - return sample - -ds = ds.map(tokenize, batched=False) - -gen_kwargs = { - "min_length":-1, - "top_k": 0.0, - "top_p": 1.0, - "do_sample": True, - "pad_token_id": gpt2_tokenizer.eos_token_id -} - -def collater(data): - return dict((key, [d[key] for d in data]) for key in data[0]) - -dataloader = torch.utils.data.DataLoader(ds, batch_size=config['batch_size'], collate_fn=collater) - -ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **config) - -total_ppo_epochs = int(np.ceil(config["steps"]/config['batch_size'])) - -for epoch, batch in tqdm(zip(range(total_ppo_epochs), iter(dataloader))): - logs, timing = dict(), dict() - t0 = time.time() - query_tensors = [torch.tensor(t).long().to(device) for t in batch["tokens"]] - - #### Get response from gpt2 - t = time.time() - response_tensors = [] - for i in range(config['batch_size']): - gen_len = output_size() - response = gpt2_model.generate(query_tensors[i].unsqueeze(dim=0), - max_new_tokens=gen_len, **gen_kwargs) - response_tensors.append(response.squeeze()[-gen_len:]) - batch['response'] = [gpt2_tokenizer.decode(r.squeeze()) for r in response_tensors] - timing['time/get_response'] = time.time()-t - - #### Compute sentiment score - t = time.time() - texts = [q + r for q,r in zip(batch['query'], batch['response'])] - pipe_outputs = sentiment_pipe(texts, **sent_kwargs) - rewards = torch.tensor([output[1]["score"] for output in pipe_outputs]).to(device) - timing['time/get_sentiment_preds'] = time.time()-t - - #### Run PPO step - t = time.time() - stats = ppo_trainer.step(query_tensors, response_tensors, rewards) - timing['time/optimization'] = time.time()-t - - #### Log everything - timing['time/epoch'] = time.time()-t0 - table_rows = [list(r) for r in zip(batch['query'], batch['response'], rewards.cpu().tolist())] - logs.update({'game_log': wandb.Table(columns=['query', 'response', 'reward'], rows=table_rows)}) - logs.update(timing) - logs.update(stats) - logs['env/reward_mean'] = torch.mean(rewards).cpu().numpy() - logs['env/reward_std'] = torch.std(rewards).cpu().numpy() - logs['env/reward_dist'] = rewards.cpu().numpy() - wandb.log(logs) \ No newline at end of file diff --git a/examples/scripts/ppo-sentiment.py b/examples/scripts/ppo-sentiment.py new file mode 100644 index 00000000000..b7ab3c7015e --- /dev/null +++ b/examples/scripts/ppo-sentiment.py @@ -0,0 +1,150 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from tqdm import tqdm +tqdm.pandas() + +from transformers import pipeline, AutoTokenizer +from datasets import load_dataset + +from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead +from trl.core import LengthSampler + +######################################################################## +# This is a fully working simple example to use trl with accelerate. +# +# This example fine-tunes a GPT2 model on the IMDB dataset using PPO +# (proximal policy optimization). +# in any of the following settings (with the same script): +# - single CPU or single GPU +# - multi GPUS (using PyTorch distributed mode) +# - multi GPUS (using DeepSpeed ZeRO-Offload stages 1 & 2) +# - fp16 (mixed-precision) or fp32 (normal precision) +# +# To run it in each of these various modes, first initialize the accelerate +# configuration with `accelerate config` +# +######################################################################## + +# We first define the configuration of the experiment, defining the model, the dataset, +# the training parameters, and the PPO parameters. +# Check the default arguments in the `PPOConfig` class for more details. +config = PPOConfig( + model_name="lvwerra/gpt2-imdb", + learning_rate=1.41e-5, +) + +# We then define the arguments to pass to the sentiment analysis pipeline. +# We set `return_all_scores` to True to get the sentiment score for each token. +sent_kwargs = { + "return_all_scores": True, + "function_to_apply": "none", + "batch_size": config.forward_batch_size +} + +# Below is an example function to build the dataset. In our case, we use the IMDB dataset +# from the `datasets` library. One should customize this function to train the model on +# its own dataset. +def build_dataset(config, dataset_name="imdb", input_min_text_length=2, input_max_text_length=8): + """ + Build dataset for training. This builds the dataset from `load_dataset`, one should + customize this function to train the model on its own dataset. + + Args: + dataset_name (`str`): + The name of the dataset to be loaded. + + Returns: + dataloader (`torch.utils.data.DataLoader`): + The dataloader for the dataset. + """ + tokenizer = AutoTokenizer.from_pretrained(config.model_name) + tokenizer.pad_token = tokenizer.eos_token + # load imdb with datasets + ds = load_dataset(dataset_name, split='train') + ds = ds.rename_columns({'text': 'review'}) + ds = ds.filter(lambda x: len(x["review"])>200, batched=False) + + input_size = LengthSampler(input_min_text_length, input_max_text_length) + + def tokenize(sample): + sample["input_ids"] = tokenizer.encode(sample["review"])[:input_size()] + sample["query"] = tokenizer.decode(sample["input_ids"]) + return sample + + ds = ds.map(tokenize, batched=False) + ds.set_format(type='torch') + return ds + +# We retrieve the dataloader by calling the `build_dataset` function. +dataset = build_dataset(config) + +def collater(data): + return dict((key, [d[key] for d in data]) for key in data[0]) + +# Now let's build the model, the reference model, and the tokenizer. +model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name) +ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name) +tokenizer = AutoTokenizer.from_pretrained(config.model_name) + +# GPT-2 tokenizer has a pad token, but it is not eos_token by default. We need to set it to eos_token. +# only for this model. +tokenizer.pad_token = tokenizer.eos_token + +# We then build the PPOTrainer, passing the model, the reference model, the tokenizer +ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collater) + +# We then build the sentiment analysis pipeline, passing the model name and the +# sentiment analysis pipeline arguments. Let's also make sure to set the device +# to the same device as the PPOTrainer. +device = ppo_trainer.accelerator.device +if ppo_trainer.accelerator.num_processes == 1: + device = 0 if torch.cuda.is_available() else "cpu" # to avoid a `pipeline` bug +sentiment_pipe = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", device=device) + +# We then define the arguments to pass to the `generate` function. These arguments +# are passed to the `generate` function of the PPOTrainer, which is a wrapper around +# the `generate` function of the trained model. +generation_kwargs = { + "min_length":-1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id +} +output_min_length = 4 +output_max_length = 16 +output_length_sampler = LengthSampler(output_min_length, output_max_length) + +for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): + query_tensors = batch['input_ids'] + + #### Get response from gpt2 + response_tensors = [] + for query in query_tensors: + gen_len = output_length_sampler() + generation_kwargs["max_new_tokens"] = gen_len + response = ppo_trainer.generate(query, **generation_kwargs) + response_tensors.append(response.squeeze()[-gen_len:]) + batch['response'] = [tokenizer.decode(r.squeeze()) for r in response_tensors] + + #### Compute sentiment score + texts = [q + r for q,r in zip(batch['query'], batch['response'])] + pipe_outputs = sentiment_pipe(texts, **sent_kwargs) + rewards = [torch.tensor(output[1]["score"]).to(device) for output in pipe_outputs] + + #### Run PPO step + stats = ppo_trainer.step(query_tensors, response_tensors, rewards) + ppo_trainer.log_stats(stats, batch, rewards) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index dccba71ddd4..21478a8574b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,5 +5,6 @@ datasets==1.17.0 torch>=1.4.0 tqdm transformers +accelerate wandb==0.10.20 matplotlib==3.5.1 diff --git a/setup.py b/setup.py index 0ed4fae480c..10157b55400 100644 --- a/setup.py +++ b/setup.py @@ -22,8 +22,8 @@ requirements = cfg.get('requirements','').split() extras = { - "test" : ["pytest","pytest-xdist",], - "dev" : ["pytest","pytest-xdist", "black", "isort", "flake8>=3.8.3"], + "test" : ["pytest","pytest-xdist","accelerate", "datasets", "wandb"], + "dev" : ["pytest","pytest-xdist", "black", "isort", "flake8>=3.8.3", "accelerate", "datasets", "wandb"], } lic = licenses[cfg['license']] min_python = cfg['min_python'] diff --git a/tests/test_gpt2_model.py b/tests/test_gpt2_model.py index 762628d0aec..bff533facc0 100644 --- a/tests/test_gpt2_model.py +++ b/tests/test_gpt2_model.py @@ -4,9 +4,20 @@ from transformers import GPT2Tokenizer from trl import AutoModelForCausalLMWithValueHead -from trl.gpt2 import respond_to_batch +from trl.core import respond_to_batch -from trl.ppo import PPOTrainer +from trl import PPOTrainer, PPOConfig + +class DummyDataset(torch.utils.data.Dataset): + def __init__(self, query_data, response_data): + self.query_data = query_data + self.response_data = response_data + + def __len__(self): + return len(self.query_data) + + def __getitem__(self, idx): + return self.query_data[idx], self.response_data[idx] def test_gpt2_model(): @@ -16,8 +27,8 @@ def test_gpt2_model(): gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") # initialize trainer - ppo_config = {"batch_size": 1, "forward_batch_size": 1} - ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **ppo_config) + ppo_config = {"batch_size": 2, "forward_batch_size": 1, "log_with_wandb": False} + ppo_config = PPOConfig(**ppo_config) # encode a query query_txt = "This morning I went to the " @@ -28,12 +39,23 @@ def test_gpt2_model(): assert response_tensor.shape == (1, 20) response_txt = gpt2_tokenizer.decode(response_tensor[0, :]) - # define a reward for response - # (this could be any reward such as human feedback or output from another model) - reward = [torch.tensor(1.0)] + # create a dummy dataset + min_length = min(len(query_tensor[0]), len(response_tensor[0])) + dummy_dataset = DummyDataset([query_tensor[:, :min_length].squeeze(0) for _ in range(2)], [response_tensor[:, :min_length].squeeze(0) for _ in range(2)]) + dummy_dataloader = torch.utils.data.DataLoader( + dummy_dataset, batch_size=2, shuffle=True + ) + ppo_trainer = PPOTrainer(config=ppo_config, model=gpt2_model, ref_model=gpt2_model_ref, tokenizer=gpt2_tokenizer, dataset=dummy_dataset) + dummy_dataloader = ppo_trainer.dataloader # train model with ppo - train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward) + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + # train model + train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward) + break EXPECTED_STATS = [ "objective/kl", diff --git a/trl/__init__.py b/trl/__init__.py index 148bc0aa96b..2a5d2bf61f7 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -1,3 +1,4 @@ __version__ = "0.1.1" -from .models import AutoModelForCausalLMWithValueHead \ No newline at end of file +from .models import AutoModelForCausalLMWithValueHead +from .trainer import PPOTrainer, PPOConfig \ No newline at end of file diff --git a/trl/core.py b/trl/core.py index 30ca719a1ab..26851dc534c 100644 --- a/trl/core.py +++ b/trl/core.py @@ -1,6 +1,7 @@ import torch import torch.nn.functional as F from torch.nn.utils.rnn import pad_sequence +from transformers import top_k_top_p_filtering import collections import numpy as np @@ -122,4 +123,27 @@ def build_bert_batch_from_txt(text_list, tokenizer, device): padded_tensors = torch.cat(padded_tensors) attention_masks = torch.cat(attention_masks) - return padded_tensors, attention_masks \ No newline at end of file + return padded_tensors, attention_masks + +def respond_to_batch(model, queries, txt_len=20, top_k=0, top_p=1.0): + """Sample text from language model.""" + input_ids = queries + for i in range(txt_len): + # Get Logits + outputs = model(input_ids) + next_token_logits = outputs[0][:, -1, :] + next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) + # Sample + probs = F.softmax(next_token_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1).squeeze(1) + input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1) + return input_ids[:, -txt_len:] + +class LengthSampler: + """ + Samples a length + """ + def __init__(self, min_value, max_value): + self.values = list(range(min_value, max_value)) + def __call__(self): + return np.random.choice(self.values) \ No newline at end of file diff --git a/trl/gpt2.py b/trl/gpt2.py deleted file mode 100644 index 827c1031720..00000000000 --- a/trl/gpt2.py +++ /dev/null @@ -1,183 +0,0 @@ -# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/01-gpt2-with-value-head.ipynb (unless otherwise specified). - -__all__ = ['CausalLMOutputWithCrossAttentions', 'ValueHead', 'GPT2HeadWithValueModel', 'respond_to_batch'] - -# Cell - -from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Model, GPT2PreTrainedModel -from transformers import top_k_top_p_filtering -from transformers.modeling_outputs import ModelOutput -from torch import nn -from torch.nn import Identity -import torch.nn.functional as F -import torch -from dataclasses import dataclass -from typing import Optional, Tuple - -# Cell -@dataclass -class CausalLMOutputWithCrossAttentions(ModelOutput): - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor]] = None - value: Optional[torch.FloatTensor] = None - -# Cell - -class ValueHead(nn.Module): - """The ValueHead class implements a head for GPT2 that returns a scalar for each output token.""" - def __init__(self, config): - super().__init__() - self.detach_head = False - self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last" - if self.summary_type == "attn": - raise NotImplementedError - - self.summary = Identity() - if hasattr(config, "summary_use_proj") and config.summary_use_proj: - if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: - num_classes = config.num_labels - else: - num_classes = config.hidden_size - self.summary = nn.Linear(config.hidden_size, num_classes) - - self.activation = Identity() - if hasattr(config, "summary_activation") and config.summary_activation == "tanh": - self.activation = nn.Tanh() - - self.first_dropout = Identity() - if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: - self.first_dropout = nn.Dropout(config.summary_first_dropout) - - self.last_dropout = Identity() - if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: - self.last_dropout = nn.Dropout(config.summary_last_dropout) - - self.flatten = nn.Flatten() - - def forward(self, hidden_states, cls_index=None): - if self.detach_head: - output = hidden_states.detach() - else: - output = hidden_states - output = self.first_dropout(output) - output = self.summary(output) - output = self.activation(output) - output = self.last_dropout(output) - - return output - -# Cell - -class GPT2HeadWithValueModel(GPT2PreTrainedModel): - """The GPT2HeadWithValueModel class implements a GPT2 language model with a secondary, scalar head.""" - def __init__(self, config): - super().__init__(config) - config.num_labels = 1 - self.transformer = GPT2Model(config) - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - self.v_head = ValueHead(config) - - self.init_weights() - - def get_output_embeddings(self): - return self.lm_head - - def detach_value_head(self): - self.v_head.detach_head = True - - def forward( - self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - mc_token_ids=None, - lm_labels=None, - mc_labels=None, - return_dict=False, - output_attentions=False, - output_hidden_states=False, - use_cache=True, - ): - loss=None - transformer_outputs = self.transformer( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - ) - - hidden_states = transformer_outputs[0] - - lm_logits = self.lm_head(hidden_states) - value = self.v_head(hidden_states).squeeze(-1) - - - if not return_dict: - outputs = (lm_logits, loss, value,) - return outputs - - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - cross_attentions=transformer_outputs.cross_attentions, - value=value, - ) - - def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): - token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs - if past: - input_ids = input_ids[:, -1].unsqueeze(-1) - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past: - position_ids = position_ids[:, -1].unsqueeze(-1) - else: - position_ids = None - return { - "input_ids": input_ids, - "past_key_values": past, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": token_type_ids, - } - -# Cell - -def respond_to_batch(model, queries, txt_len=20, top_k=0, top_p=1.0): - """Sample text from language model.""" - input_ids = queries - for i in range(txt_len): - # Get Logits - outputs = model(input_ids) - next_token_logits = outputs[0][:, -1, :] - next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) - # Sample - probs = F.softmax(next_token_logits, dim=-1) - next_token = torch.multinomial(probs, num_samples=1).squeeze(1) - input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1) - return input_ids[:, -txt_len:] \ No newline at end of file diff --git a/trl/models/__init__.py b/trl/models/__init__.py index 198061e0b88..96512528b87 100644 --- a/trl/models/__init__.py +++ b/trl/models/__init__.py @@ -11,4 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .modeling_vhead import AutoModelForCausalLMWithValueHead \ No newline at end of file +from .modeling_vhead import AutoModelForCausalLMWithValueHead +from .modeling_base import PreTrainedModelWrapper + +SUPPORTED_ARCHITECTURES = ( + AutoModelForCausalLMWithValueHead, +) \ No newline at end of file diff --git a/trl/ppo.py b/trl/ppo.py deleted file mode 100644 index 918ca1e5ee2..00000000000 --- a/trl/ppo.py +++ /dev/null @@ -1,310 +0,0 @@ -# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/02-ppo.ipynb (unless otherwise specified). - -__all__ = ['AdaptiveKLController', 'FixedKLController', 'PPOTrainer'] - -# Cell -import numpy as np -import torch.nn.functional as F -from torch.optim import Adam -import torch -import collections -import time -import random - -from transformers import DataCollatorForLanguageModeling - -from .core import (logprobs_from_logits, - whiten, - clip_by_value, - entropy_from_logits, - flatten_dict, - average_torch_dicts, - stats_to_np, - stack_dicts, - add_suffix, - WANDB_PADDING) - -# Cell - -class AdaptiveKLController: - """ - Adaptive KL controller described in the paper: - https://arxiv.org/pdf/1909.08593.pdf - """ - def __init__(self, init_kl_coef, target, horizon): - self.value = init_kl_coef - self.target = target - self.horizon = horizon - - def update(self, current, n_steps): - target = self.target - proportional_error = np.clip(current / target - 1, -0.2, 0.2) - mult = 1 + proportional_error * n_steps / self.horizon - self.value *= mult - -# Cell - -class FixedKLController: - """Fixed KL controller.""" - def __init__(self, kl_coef): - self.value = kl_coef - - def update(self, current, n_steps): - pass - -# Cell - -class PPOTrainer: - """ - The PPO_trainer uses Proximal Policy Optimization to optimise language models. - """ - - default_params = { - "lr": 1.41e-5, - "adap_kl_ctrl": True, - "init_kl_coef":0.2, - "target": 6, - "horizon":10000, - "gamma":1, - "lam":0.95, - "cliprange": .2, - "cliprange_value":.2, - "vf_coef":.1, - "batch_size": 256, - "forward_batch_size": 16, - "ppo_epochs": 4, - } - - def __init__(self, model, ref_model, tokenizer, **ppo_params): - """ - Initialize PPOTrainer. - - Args: - model (torch.model): Hugging Face transformer GPT2 model with value head - ref_model (torch.model): Hugging Face transformer GPT2 refrence model used for KL penalty - tokenizer (tokenizer): Hugging Face tokenizer - ppo_params (dict or None): PPO parameters for training. Can include following keys: - 'lr' (float): Adam learning rate, default: 1.41e-5 - 'batch_size' (int): Number of samples per optimisation step, default: 256 - 'forward_batch_size' (int): Number of samples forward passed through model at a time, default: 16 - 'ppo_epochs' (int): Number of optimisation epochs per batch of samples, default: 4 - 'gamma' (float)): Gamma parameter for advantage calculation, default: 1. - 'lam' (float): Lambda parameter for advantage calcualation, default: 0.95 - 'cliprange_value' (float): Range for clipping values in loss calculation, default: 0.2 - 'cliprange' (float): Range for clipping in PPO policy gradient loss, default: 0.2 - 'vf_coef' (float): Scaling factor for value loss, default: 0.1 - 'adap_kl_ctrl' (bool): Use adaptive KL control, otherwise linear, default: True - 'init_kl_coef' (float): Initial KL penalty coefficient (used for adaptive and linear control), default: 0.2 - 'target' (float): Target KL value for adaptive KL control, default: 6.0 - 'horizon' (float): Horizon for adaptive KL control, default: 10000 - - """ - self.ppo_params = self.default_params - self.ppo_params.update(ppo_params) - - self.ref_model = ref_model - self.model = model - self.tokenizer = tokenizer - self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) - - self.optimizer = Adam(model.parameters(), lr=self.ppo_params['lr']) - - if self.ppo_params['adap_kl_ctrl']: - self.kl_ctl = AdaptiveKLController(self.ppo_params['init_kl_coef'], - self.ppo_params['target'], - self.ppo_params['horizon']) - else: - self.kl_ctl = FixedKLController(self.ppo_params['init_kl_coef']) - - - def step(self, queries, responses, scores): - """ - Run a PPO optimisation step. - - args: - queries (List): List of tensors containing the encoded queries, shape [query_length] - responses (List): List of tensors containing the encoded responses, shape [response_length] - scores (List): tensor containing the scores, shape [batch_size] - - returns: - train_stats (dict): a summary of the training statistics - """ - - bs = self.ppo_params['batch_size'] - assert bs == len(queries), f"Batch size ({bs}) does not match number of examples ({len(queries)})" - - timing = dict() - t0 = time.time() - - response_lengths = [len(r) for r in responses] - - t = time.time() - logprobs, ref_logprobs, values = self.batched_forward_pass(queries, responses) - timing['time/ppo/forward_pass'] = time.time()-t - - t = time.time() - rewards, non_score_reward = self.compute_rewards(scores, logprobs, ref_logprobs) - timing['time/ppo/compute_rewards'] = time.time()-t - - t = time.time() - all_stats = [] - idxs = list(range(bs)) - for _ in range(self.ppo_params['ppo_epochs']): - random.shuffle(idxs) - for i in range(bs): - idx = idxs[i] - train_stats = self.train_minibatch(logprobs[idx].unsqueeze(0), values[idx].unsqueeze(0), - rewards[idx].unsqueeze(0), queries[idx].unsqueeze(0), - responses[idx].unsqueeze(0), - torch.cat([queries[idx],responses[idx]]).unsqueeze(0)) - all_stats.append(train_stats) - timing['time/ppo/optimize_step'] = time.time()-t - - t = time.time() - train_stats = stack_dicts(all_stats) - - # reshape advantages/ratios such that they are not averaged. - train_stats['policy/advantages'] = torch.flatten(train_stats['policy/advantages']).unsqueeze(0) - train_stats['policy/advantages'] = torch.nan_to_num(train_stats['policy/advantages'], WANDB_PADDING) - train_stats['policy/ratio'] = torch.flatten(train_stats['policy/ratio']).unsqueeze(0) - - stats = self.record_step_stats(scores=scores, logprobs=logprobs, ref_logprobs=ref_logprobs, - non_score_reward=non_score_reward, train_stats=train_stats, - kl_coef=self.kl_ctl.value) - stats = stats_to_np(stats) - timing['time/ppo/calc_stats'] = time.time()-t - - self.kl_ctl.update(stats['objective/kl'], self.ppo_params['batch_size']) - - timing['time/ppo/total'] = time.time()-t0 - stats.update(timing) - return stats - - def batched_forward_pass(self, queries, responses): - """Calculate model outputs in multiple batches.""" - bs = self.ppo_params['batch_size'] - fbs = self.ppo_params['forward_batch_size'] - all_logprobs = [] - all_ref_logprobs = [] - all_values = [] - - for i in range(int(bs/fbs)): - query_batch = queries[i*fbs:(i+1)*fbs] - response_batch = responses[i*fbs:(i+1)*fbs] - input_ids = self.data_collator([torch.cat([q, r]) for q, r in zip(query_batch, response_batch)])["input_ids"] - with torch.no_grad(): - logits, _, v = self.model(input_ids) - ref_logits, _, _ = self.ref_model(input_ids) - logprobs = logprobs_from_logits(logits[:,:-1,:], input_ids[:,1:]) - ref_logprobs = logprobs_from_logits(ref_logits[:,:-1,:], input_ids[:,1:]) - for j in range(fbs): - start = len(query_batch[j])-1 - end = len(query_batch[j]) + len(response_batch[j])-1 - all_values.append(v[j, start-1:end-1]) - all_logprobs.append(logprobs[j, start:end]) - all_ref_logprobs.append(ref_logprobs[j, start:end]) - return all_logprobs, all_ref_logprobs, all_values - - def train_minibatch(self, logprobs, values, rewards, query, response, model_input): - """Train one PPO minibatch""" - loss_p, loss_v, train_stats = self.loss(logprobs, values, rewards, query, response, model_input) - loss = loss_p + loss_v - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() - return train_stats - - def compute_rewards(self, scores, logprobs, ref_logprobs): - """Compute per token rewards from scores and KL-penalty.""" - rewards, non_score_rewards = [], [] - for score, logprob, ref_logprob in zip(scores, logprobs, ref_logprobs): - kl = logprob - ref_logprob - non_score_reward = -self.kl_ctl.value * kl - non_score_rewards.append(non_score_reward) - reward = non_score_reward.clone() - reward[-1] += score - rewards.append(reward) - return rewards, non_score_rewards - - def loss(self, old_logprobs, values, rewards, query, response, model_input): - """Calculate policy and value losses.""" - lastgaelam = 0 - advantages_reversed = [] - gen_len = response.shape[1] - - for t in reversed(range(gen_len)): - nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 - delta = rewards[:, t] + self.ppo_params['gamma'] * nextvalues - values[:, t] - lastgaelam = delta + self.ppo_params['gamma'] * self.ppo_params['lam'] * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1) - - returns = advantages + values - advantages = whiten(advantages) - advantages = advantages.detach() - - logits, _, vpred = self.model(model_input) - logprob = logprobs_from_logits(logits[:,:-1,:], model_input[:, 1:]) - - #only the generation part of the values/logprobs is needed - logprob, vpred = logprob[:, -gen_len:], vpred[:,-gen_len-1:-1] - - vpredclipped = clip_by_value(vpred, - values - self.ppo_params["cliprange_value"], - values + self.ppo_params["cliprange_value"]) - - vf_losses1 = (vpred - returns)**2 - vf_losses2 = (vpredclipped - returns)**2 - vf_loss = .5 * torch.mean(torch.max(vf_losses1, vf_losses2)) - vf_clipfrac = torch.mean(torch.gt(vf_losses2, vf_losses1).double()) - - ratio = torch.exp(logprob - old_logprobs) - - - pg_losses = -advantages * ratio - pg_losses2 = -advantages * torch.clamp(ratio, - 1.0 - self.ppo_params['cliprange'], - 1.0 + self.ppo_params['cliprange']) - - pg_loss = torch.mean(torch.max(pg_losses, pg_losses2)) - pg_clipfrac = torch.mean(torch.gt(pg_losses2, pg_losses).double()) - - loss = pg_loss + self.ppo_params['vf_coef'] * vf_loss - - entropy = torch.mean(entropy_from_logits(logits)) - approxkl = .5 * torch.mean((logprob - old_logprobs)**2) - policykl = torch.mean(logprob - old_logprobs) - return_mean, return_var = torch.mean(returns), torch.var(returns) - value_mean, value_var = torch.mean(values), torch.var(values) - - stats = dict( - loss=dict(policy=pg_loss, value=vf_loss, total=loss), - policy=dict(entropy=entropy, approxkl=approxkl,policykl=policykl, clipfrac=pg_clipfrac, - advantages=advantages, advantages_mean=torch.mean(advantages), ratio=ratio), - returns=dict(mean=return_mean, var=return_var), - val=dict(vpred=torch.mean(vpred), error=torch.mean((vpred - returns) ** 2), - clipfrac=vf_clipfrac, mean=value_mean, var=value_var), - ) - return pg_loss, self.ppo_params['vf_coef'] * vf_loss, flatten_dict(stats) - - - def record_step_stats(self, kl_coef, **data): - """Record training step statistics.""" - kl_list = [logprobs-ref_logprobs for logprobs, ref_logprobs in zip(data['logprobs'], data['ref_logprobs'])] - mean_kl = torch.mean(torch.stack([torch.sum(kl) for kl in kl_list])) - mean_entropy = torch.mean(torch.stack([torch.sum(-log_probs) for log_probs in data['logprobs']])) - mean_non_score_reward =torch.mean(torch.stack([torch.sum(non_score_reward) for non_score_reward in data['non_score_reward']])) - stats = { - 'objective/kl': mean_kl, - 'objective/kl_dist': kl_list, - 'objective/logprobs': data['logprobs'], - 'objective/ref_logprobs': data['ref_logprobs'], - 'objective/kl_coef': kl_coef, - 'objective/entropy': mean_entropy, - 'ppo/mean_non_score_reward': mean_non_score_reward, - } - - for k, v in data['train_stats'].items(): - stats[f'ppo/{k}'] = torch.mean(v, axis=0) - stats['ppo/val/var_explained'] = 1 - stats['ppo/val/error'] / stats['ppo/returns/var'] - return stats diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py new file mode 100644 index 00000000000..4fc4a1b280f --- /dev/null +++ b/trl/trainer/__init__.py @@ -0,0 +1,4 @@ +from .base import BaseTrainer +from .utils import AdaptiveKLController, FixedKLController +from .ppo_trainer import PPOTrainer +from .ppo_config import PPOConfig \ No newline at end of file diff --git a/trl/trainer/base.py b/trl/trainer/base.py new file mode 100644 index 00000000000..25dbd16ac8a --- /dev/null +++ b/trl/trainer/base.py @@ -0,0 +1,39 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +class BaseTrainer(object): + r""" + Base class for all trainers - this base class implements the basic functions that we + need for a trainer. + + The trainer needs to have the following functions: + - step: takes in a batch of data and performs a step of training + - loss: takes in a batch of data and returns the loss + - compute_rewards: takes in a batch of data and returns the rewards + - _build_models_and_tokenizer: builds the models and tokenizer + - _build_dataset: builds the dataset + Each user is expected to implement their own trainer class that inherits from this base + if they want to use a new training algorithm. + """ + def __init__(self, config): + self.config = config + + def step(self, *args): + raise NotImplementedError("Not implemented") + + def loss(self, *args): + raise NotImplementedError("Not implemented") + + def compute_rewards(self, *args): + raise NotImplementedError("Not implemented") \ No newline at end of file diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py new file mode 100644 index 00000000000..4793ed7af78 --- /dev/null +++ b/trl/trainer/ppo_config.py @@ -0,0 +1,107 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +from dataclasses import dataclass +from typing import Optional + +@dataclass +class PPOConfig(object): + """ + Configuration class for PPOTrainer + + Args: + model_name (`str`, *optional*, defaults to `None`): + Name of model to use - used only for tracking purposes + steps (`int`, *optional*, defaults to 20000): + Number of training steps + learning_rate (`float`, *optional*, defaults to 1.41e-5): + Adam learning rate + adap_kl_ctrl (`bool`, *optional*, defaults to True): + Use adaptive KL control, otherwise linear + init_kl_coef (`float`, *optional*, defaults to 0.2): + Initial KL penalty coefficient (used for adaptive and linear control) + target (`float`, *optional*, defaults to 6): + Target KL value for adaptive KL control + horizon (`float`, *optional*, defaults to 10000): + Horizon for adaptive KL control + gamma (`float`, *optional*, defaults to 1): + Gamma parameter for advantage calculation + lam (`float`, *optional*, defaults to 0.95): + Lambda parameter for advantage calcualation + cliprange (`float`, *optional*, defaults to 0.2): + Range for clipping in PPO policy gradient loss + cliprange_value (`float`, *optional*, defaults to 0.2): + Range for clipping values in loss calculation + vf_coef (`float`, *optional*, defaults to 0.1): + Scaling factor for value loss + batch_size (`int`, *optional*, defaults to 256): + Number of samples per optimisation step + forward_batch_size (`int`, *optional*, defaults to 16): + Number of samples forward passed through model at a time + ppo_epochs (`int`, *optional*, defaults to 4): + Number of optimisation epochs per batch of samples + remove_unused_columns (`bool`, *optional*, defaults to True): + Remove unused columns from the dataset if `datasets.Dataset` is used + log_with_wandb (`bool`, *optional*, defaults to True): + Log with wandb + wandb_project (`str`, *optional*, defaults to "trl"): + Name of wandb project + """ + def __init__( + self, + model_name: Optional[str] = None, + steps: Optional[int] = 20000, + learning_rate: Optional[float] = 1e-5, + adap_kl_ctrl: Optional[bool] = True, + init_kl_coef: Optional[float] = 0.2, + target: Optional[float] = 6, + horizon: Optional[float] = 10000, + gamma: Optional[float] = 1, + lam: Optional[float] = 0.95, + cliprange: Optional[float] = 0.2, + cliprange_value: Optional[float] = 0.2, + vf_coef: Optional[float] = 0.1, + batch_size: Optional[int] = 256, + forward_batch_size: Optional[int] = 16, + ppo_epochs: Optional[int] = 4, + remove_unused_columns: Optional[bool] = True, + log_with_wandb: Optional[bool] = True, + wandb_project: Optional[str] = "trl", + ): + self.model_name = model_name + self.steps = steps + self.learning_rate = learning_rate + self.adap_kl_ctrl = adap_kl_ctrl + self.init_kl_coef = init_kl_coef + self.target = target + self.horizon = horizon + self.gamma = gamma + self.lam = lam + self.cliprange = cliprange + self.cliprange_value = cliprange_value + self.vf_coef = vf_coef + self.batch_size = batch_size + self.forward_batch_size = forward_batch_size + self.ppo_epochs = ppo_epochs + self.remove_unused_columns = remove_unused_columns + self.log_with_wandb = log_with_wandb + self.wandb_project = wandb_project + + self.total_ppo_epochs = int(np.ceil(steps/batch_size)) + + def to_dict(self): + output_dict = {} + for key, value in self.__dict__.items(): + output_dict[key] = value + return output_dict \ No newline at end of file diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py new file mode 100644 index 00000000000..0d7b48f3595 --- /dev/null +++ b/trl/trainer/ppo_trainer.py @@ -0,0 +1,611 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from packaging import version +import inspect +import warnings +from typing import Any, List, Optional, Union +from accelerate import Accelerator +import datasets +from datasets import Dataset + +from torch.optim import Adam +import torch +import time +import random +import wandb + +from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizer, PreTrainedTokenizerFast + +from trl.core import (logprobs_from_logits, + whiten, + clip_by_value, + entropy_from_logits, + flatten_dict, + stats_to_np, + stack_dicts, + WANDB_PADDING) +from trl.trainer import BaseTrainer, AdaptiveKLController, FixedKLController +from trl.models import PreTrainedModelWrapper, SUPPORTED_ARCHITECTURES + + +class PPOTrainer(BaseTrainer): + """ + The PPOTrainer uses Proximal Policy Optimization to optimise language models. + """ + def __init__( + self, + config, + model: PreTrainedModelWrapper, + ref_model: PreTrainedModelWrapper, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + dataset: Union[torch.utils.data.Dataset, Dataset], + data_collator = None, + ): + """ + Initialize PPOTrainer. + + Args: + config (`PPOConfig`): + Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more details. + model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a value head. + ref_model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a casual language modelling head. Used for KL penalty + tokenizer (`transformers.PreTrainedTokenizer`): + Hugging Face tokenizer + dataset (Union[`torch.utils.data.Dataset`, `datasets.Dataset`]): + PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset + will be preprocessed by removing the columns that are not used by the model. + data_collator (Optional[function]): + Data collator function. + """ + super().__init__(config) + + # Step 1: Initialize Accelerator + self.accelerator = Accelerator(log_with="wandb") + + # Step 2: Initialize model, tokenizer, and dataloader + if not isinstance(model, PreTrainedModelWrapper): + raise ValueError(f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}") + self.model = model + + if not isinstance(ref_model, PreTrainedModelWrapper): + raise ValueError(f"model must be a PreTrainedModelWrapper, got {type(ref_model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}") + self.ref_model = ref_model + + if not (isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast)): + raise ValueError("tokenizer must be a transformers.PreTrainedTokenizer or transformers.PreTrainedTokenizerFast") + self.tokenizer = tokenizer + + if not (isinstance(dataset, torch.utils.data.Dataset) or isinstance(dataset, Dataset)): + raise ValueError("dataloader must be a torch.utils.data.Dataset or datasets.Dataset") + self.dataset = dataset + self._signature_columns = None + self.dataloader = self.prepare_dataloader(self.dataset, data_collator) + + # Step 3: Initialize optimizer and data collator + self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) + self.optimizer = Adam(self.model.parameters(), lr=self.config.learning_rate) + + if self.config.adap_kl_ctrl: + self.kl_ctl = AdaptiveKLController(self.config.init_kl_coef, + self.config.target, + self.config.horizon) + else: + self.kl_ctl = FixedKLController(self.config.init_kl_coef) + + self.model, self.ref_model, self.optimizer, self.data_collator, self.dataloader = self.accelerator.prepare(self.model, self.ref_model, self.optimizer, self.data_collator, self.dataloader) + + # In a distributed setup, only logging needs to be performed on the main process + # check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html + # or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11 + self.is_distributed = self.accelerator.distributed_type == "MULTI_GPU" + + # init wandb on the main process: + if self.accelerator.is_main_process and self.config.log_with_wandb: + wandb.init(name='run-42', project=self.config.wandb_project, config=config) + wandb.watch(self.model, log='all') + + def prepare_dataloader(self, dataset: Union[torch.utils.data.Dataset, Dataset], data_collator = None): + """ + Prepare the dataloader for training. + + Args: + dataset (Union[`torch.utils.data.Dataset`, `datasets.Dataset`]): + PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset + will be preprocessed by removing the columns that are not used by the model. + data_collator (Optional[function]): + Data collator function. + + Returns: + `torch.utils.data.DataLoader`: + PyTorch dataloader + """ + if isinstance(dataset, Dataset): + dataset = self._remove_unused_columns(dataset) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=self.config.batch_size, + collate_fn=data_collator, + shuffle=True, + ) + return dataloader + + # Adapted from transformers.Trainer._set_signature_columns_if_needed + def _set_signature_columns_if_needed(self): + if self._signature_columns is None: + # Inspect model forward signature to keep only the arguments it accepts. + signature = inspect.signature(self.model.forward) + self._signature_columns = list(signature.parameters.keys()) + # label => sentiment | we need query and response for logging purpose + self._signature_columns += list(set(["label", "query", "response"])) + + # Adapted from transformers.Trainer._remove_unused_columns + def _remove_unused_columns(self, dataset: "Dataset"): + if not self.config.remove_unused_columns: + return dataset + self._set_signature_columns_if_needed() + signature_columns = self._signature_columns + + ignored_columns = list(set(dataset.column_names) - set(signature_columns)) + + columns = [k for k in signature_columns if k in dataset.column_names] + + if version.parse(datasets.__version__) < version.parse("1.4.0"): + dataset.set_format( + type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"] + ) + return dataset + else: + return dataset.remove_columns(ignored_columns) + + + def generate(self, query_tensor: torch.Tensor, **generation_kwargs): + """ + Generate response given query. + + Args: + query_tensor (`torch.LongTensor`): + A tensor of shape (`batch_size`, `seq_len`) containing query tokens. + gen_kwargs (dict[str, Any]): + Keyword arguments for generation. + + Returns: + response (`torch.LongTensor`): + A tensor of shape (`batch_size`, `gen_len`) containing response tokens. + """ + response = self.accelerator.unwrap_model(self.model).generate(query_tensor.unsqueeze(dim=0),**generation_kwargs) + + return response + + + def _step_safety_checker( + self, + batch_size: int, + queries: List[torch.LongTensor], + responses: List[torch.LongTensor], + scores: List[torch.FloatTensor], + ): + """ + Check if the input data is valid for training. + + Args: + batch_size (int): + Batch size + queries (List[`torch.LongTensor`]): + List of tensors containing the encoded queries of shape (`query_length`) + responses (List[`torch.LongTensor`]): + List of tensors containing the encoded responses of shape (`response_length`) + scores (List[`torch.FloatTensor`]): + List of tensors containing the scores. + """ + for name, tensor_list in zip(["queries", "responses", "scores"], [queries, responses, scores]): + if not isinstance(tensor_list, list): + raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}") + if not isinstance(tensor_list[0], torch.Tensor): + raise ValueError(f"Elements in {name} must tensors - got {type(tensor_list[0])}") + if len(tensor_list) != batch_size: + raise ValueError(f"Batch size ({batch_size}) does not match number of examples - but got {len(tensor_list)} for: {name}") + + def step( + self, + queries: List[torch.LongTensor], + responses: List[torch.LongTensor], + scores: List[torch.FloatTensor], + ): + """ + Run a PPO optimisation step. + + Args: + queries (List[`torch.LongTensor`]): + List of tensors containing the encoded queries of shape (`query_length`) + responses (List[`torch.LongTensor`]): + List of tensors containing the encoded responses of shape (`response_length`) + scores (List[`torch.FloatTensor`]): + List of tensors containing the scores. + + Returns: + train_stats (dict[str, Any]): + a summary of the training statistics + """ + + bs = self.config.batch_size + + self._step_safety_checker(bs, queries, responses, scores) + + timing = dict() + t0 = time.time() + + t = time.time() + + logprobs, ref_logprobs, values = self.batched_forward_pass(queries, responses) + timing['time/ppo/forward_pass'] = time.time()-t + + t = time.time() + rewards, non_score_reward = self.compute_rewards(scores, logprobs, ref_logprobs) + timing['time/ppo/compute_rewards'] = time.time()-t + + t = time.time() + all_stats = [] + idxs = list(range(bs)) + for _ in range(self.config.ppo_epochs): + random.shuffle(idxs) + for i in range(bs): + idx = idxs[i] + train_stats = self.train_minibatch(logprobs[idx].unsqueeze(0), values[idx].unsqueeze(0), + rewards[idx].unsqueeze(0), queries[idx].unsqueeze(0), + responses[idx].unsqueeze(0), + torch.cat([queries[idx],responses[idx]]).unsqueeze(0)) + all_stats.append(train_stats) + timing['time/ppo/optimize_step'] = time.time()-t + + t = time.time() + train_stats = stack_dicts(all_stats) + + # reshape advantages/ratios such that they are not averaged. + train_stats['policy/advantages'] = torch.flatten(train_stats['policy/advantages']).unsqueeze(0) + train_stats['policy/advantages'] = torch.nan_to_num(train_stats['policy/advantages'], WANDB_PADDING) + train_stats['policy/ratio'] = torch.flatten(train_stats['policy/ratio']).unsqueeze(0) + + stats = self.record_step_stats(scores=scores, logprobs=logprobs, ref_logprobs=ref_logprobs, + non_score_reward=non_score_reward, train_stats=train_stats, + kl_coef=self.kl_ctl.value) + # Gather/Reduce stats from all processes + if self.is_distributed: + stats = self.gather_stats(stats) + stats = stats_to_np(stats) + timing['time/ppo/calc_stats'] = time.time()-t + + # Update the KL control - multiply the batch_size by the number of processes + self.kl_ctl.update(stats['objective/kl'], self.config.batch_size * self.accelerator.num_processes) + + # Log the total ppo time + timing['time/ppo/total'] = time.time()-t0 + stats.update(timing) + return stats + + def gather_stats(self, stats): + """ + Gather stats from all processes. Useful in the context of distributed training. + + Args: + stats (dict[str, Any]): + a dictionary of stats to be gathered. The stats should contain torch tensors. + + Returns: + stats (dict[str, Any]): + a dictionary of stats with the tensors gathered. + """ + import torch.distributed as dist + + # Wait for all processes to finish + dist.barrier() + + for k, v in stats.items(): + if isinstance(v, torch.Tensor): + dist.all_reduce(v, dist.ReduceOp.SUM) + v /= self.accelerator.num_processes + stats[k] = v + return stats + + def batched_forward_pass(self, queries: torch.Tensor, responses: torch.Tensor): + """ + Calculate model outputs in multiple batches. + + Args: + queries (`torch.LongTensor`): + List of tensors containing the encoded queries, shape (`batch_size`, `query_length`) + responses (`torch.LongTensor`): + List of tensors containing the encoded responses, shape (`batch_size`, `response_length`) + + Returns: + all_logprobs (`torch.FloatTensor`): + List of tensors containing the logprobs, shape (`batch_size`, `response_length`) + all_ref_logprobs (`torch.FloatTensor`): + List of tensors containing the logprobs from the reference model, shape (`batch_size`, `response_length`) + all_values (`torch.FloatTensor`): + List of tensors containing the output from the value head, shape (`batch_size`, `response_length`) + + """ + bs = self.config.batch_size + fbs = self.config.forward_batch_size + all_logprobs = [] + all_ref_logprobs = [] + all_values = [] + + for i in range(int(bs/fbs)): + query_batch = queries[i*fbs:(i+1)*fbs] + response_batch = responses[i*fbs:(i+1)*fbs] + input_ids = self.data_collator([torch.cat([q, r]) for q, r in zip(query_batch, response_batch)])["input_ids"] + with torch.no_grad(): + logits, _, v = self.model(input_ids) + ref_logits, _, _ = self.ref_model(input_ids) + logprobs = logprobs_from_logits(logits[:,:-1,:], input_ids[:,1:]) + ref_logprobs = logprobs_from_logits(ref_logits[:,:-1,:], input_ids[:,1:]) + for j in range(fbs): + start = len(query_batch[j])-1 + end = len(query_batch[j]) + len(response_batch[j])-1 + all_values.append(v[j, start-1:end-1]) + all_logprobs.append(logprobs[j, start:end]) + all_ref_logprobs.append(ref_logprobs[j, start:end]) + return all_logprobs, all_ref_logprobs, all_values + + def train_minibatch( + self, + logprobs: torch.FloatTensor, + values: torch.FloatTensor, + rewards: torch.FloatTensor, + query: torch.LongTensor, + response: torch.LongTensor, + model_input: torch.LongTensor, + ): + """ + Train one PPO minibatch + + Args: + logprobs (`torch.FloatTensor`): + Log probabilities of the model, shape [batch_size, response_length] + values (`torch.FloatTensor`): + Values of the value head, shape [batch_size, response_length] + rewards (`torch.FloatTensor`): + Rewards from the reward model, shape [batch_size, response_length] + query (`torch.LongTensor`): + Encoded queries, shape [batch_size, query_length] + response (`torch.LongTensor`): + Encoded responses, shape [batch_size, response_length] + model_input (`torch.LongTensor`): + Concatenated queries and responses, shape [batch_size, query_length+response_length] + + Returns: + train_stats (dict[str, `torch.Tensor`]): + Dictionary of training statistics + """ + loss_p, loss_v, train_stats = self.loss(logprobs, values, rewards, query, response, model_input) + loss = loss_p + loss_v + self.optimizer.zero_grad() + self.accelerator.backward(loss) + t = time.time() + self.optimizer.step() + train_stats['time/ppo/optimizer_step'] = torch.Tensor([time.time()-t]).to(self.accelerator.device) + return train_stats + + + def compute_rewards( + self, + scores: torch.FloatTensor, + logprobs: torch.FloatTensor, + ref_logprobs: torch.FloatTensor + ): + """ + Compute per token rewards from scores and KL-penalty. + + Args: + scores (`torch.FloatTensor`): + Scores from the reward model, shape (`batch_size`) + logprobs (`torch.FloatTensor`): + Log probabilities of the model, shape (`batch_size`, `response_length`) + ref_logprobs (`torch.FloatTensor`): + Log probabilities of the reference model, shape (`batch_size`, `response_length`) + """ + rewards, non_score_rewards = [], [] + for score, logprob, ref_logprob in zip(scores, logprobs, ref_logprobs): + kl = logprob - ref_logprob + non_score_reward = -self.kl_ctl.value * kl + non_score_rewards.append(non_score_reward) + reward = non_score_reward.clone() + reward[-1] += score + rewards.append(reward) + return rewards, non_score_rewards + + def loss( + self, + old_logprobs: torch.FloatTensor, + values: torch.FloatTensor, + rewards: torch.FloatTensor, + query: torch.LongTensor, + response: torch.LongTensor, + model_input: torch.LongTensor, + ): + """ + Calculate policy and value losses. + + Args: + old_logprobs (`torch.FloatTensor`): + Log probabilities of the model, shape (`batch_size`, `response_length`) + values (`torch.FloatTensor`): + Values of the value head, shape (`batch_size`, `hidden_dim`) + rewards (`torch.FloatTensor`): + Rewards from the reward model, shape (`batch_size`) + query (`torch.LongTensor`): + Encoded queries, shape (`batch_size`, `query_length`) + response (`torch.LongTensor`): + Encoded responses, shape (`batch_size`, `response_length`) + model_input (`torch.LongTensor`): + Concatenated queries and responses, shape (`batch_size`, `query_length+response_length`) + """ + lastgaelam = 0 + advantages_reversed = [] + gen_len = response.shape[1] + + for t in reversed(range(gen_len)): + nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 + delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t] + lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1) + + returns = advantages + values + advantages = whiten(advantages) + advantages = advantages.detach() + + logits, _, vpred = self.model(model_input) + logprob = logprobs_from_logits(logits[:,:-1,:], model_input[:, 1:]) + + #only the generation part of the values/logprobs is needed + logprob, vpred = logprob[:, -gen_len:], vpred[:,-gen_len-1:-1] + + vpredclipped = clip_by_value(vpred, + values - self.config.cliprange_value, + values + self.config.cliprange_value) + + vf_losses1 = (vpred - returns)**2 + vf_losses2 = (vpredclipped - returns)**2 + vf_loss = .5 * torch.mean(torch.max(vf_losses1, vf_losses2)) + vf_clipfrac = torch.mean(torch.gt(vf_losses2, vf_losses1).double()) + + ratio = torch.exp(logprob - old_logprobs) + + pg_losses = -advantages * ratio + pg_losses2 = -advantages * torch.clamp(ratio, + 1.0 - self.config.cliprange, + 1.0 + self.config.cliprange) + + pg_loss = torch.mean(torch.max(pg_losses, pg_losses2)) + pg_clipfrac = torch.mean(torch.gt(pg_losses2, pg_losses).double()) + + loss = pg_loss + self.config.vf_coef * vf_loss + + entropy = torch.mean(entropy_from_logits(logits)) + approxkl = .5 * torch.mean((logprob - old_logprobs)**2) + policykl = torch.mean(logprob - old_logprobs) + return_mean, return_var = torch.mean(returns), torch.var(returns) + value_mean, value_var = torch.mean(values), torch.var(values) + + stats = dict( + loss=dict(policy=pg_loss, value=vf_loss, total=loss), + policy=dict(entropy=entropy, approxkl=approxkl,policykl=policykl, clipfrac=pg_clipfrac, + advantages=advantages, advantages_mean=torch.mean(advantages), ratio=ratio), + returns=dict(mean=return_mean, var=return_var), + val=dict(vpred=torch.mean(vpred), error=torch.mean((vpred - returns) ** 2), + clipfrac=vf_clipfrac, mean=value_mean, var=value_var), + ) + return pg_loss, self.config.vf_coef * vf_loss, flatten_dict(stats) + + + def record_step_stats(self, kl_coef: float, **data): + """ + Record training step statistics. + + + Args: + kl_coef (`float`): + KL coefficient + data (`dict`): + Dictionary of training step data + + Returns: + stats (`dict`): + Dictionary of training step statistics + """ + kl_list = [logprobs-ref_logprobs for logprobs, ref_logprobs in zip(data['logprobs'], data['ref_logprobs'])] + mean_kl = torch.mean(torch.stack([torch.sum(kl) for kl in kl_list])) + mean_entropy = torch.mean(torch.stack([torch.sum(-log_probs) for log_probs in data['logprobs']])) + mean_non_score_reward =torch.mean(torch.stack([torch.sum(non_score_reward) for non_score_reward in data['non_score_reward']])) + stats = { + 'objective/kl': mean_kl, + 'objective/kl_dist': kl_list, + 'objective/logprobs': data['logprobs'], + 'objective/ref_logprobs': data['ref_logprobs'], + 'objective/kl_coef': kl_coef, + 'objective/entropy': mean_entropy, + 'ppo/mean_non_score_reward': mean_non_score_reward, + } + + for k, v in data['train_stats'].items(): + stats[f'ppo/{k}'] = torch.mean(v, axis=0) + stats['ppo/val/var_explained'] = 1 - stats['ppo/val/error'] / stats['ppo/returns/var'] + return stats + + + def log_stats( + self, + stats: dict, + batch: dict, + rewards: List[torch.FloatTensor], + ): + """ + A function that logs all the training stats. Call it at the end of each epoch. + + Args: + stats (dict[str, Any]): + A dictionary of training stats. + batch (dict[str, Any]): + A dictionary of batch data, this containes the queries and responses. + rewards (`List[torch.FloatTensor]`): + A tensor of rewards. + """ + # Log only if we are in the main process + if self.accelerator.is_main_process: + wandb_logs = {} + + # Log stats + if not isinstance(rewards, torch.Tensor): + rewards = torch.tensor(rewards).to(self.accelerator.device) + + if ("query" not in batch.keys() and "response" not in batch.keys()): + # warn the user that the game logs will not be logged + warnings.warn("The game logs will not be logged because the batch does not contain the keys 'query' and 'response'.") + elif self.config.log_with_wandb: + table_rows = [list(r) for r in zip(batch['query'], batch['response'], rewards.cpu().tolist())] + wandb_logs.update({'game_log': wandb.Table(columns=['query', 'response', 'reward'], rows=table_rows)}) + # All reduce rewards if distributed + if self.is_distributed: + import torch.distributed as dist + + dist.barrier() + + + dist.all_reduce(rewards, op=torch.distributed.ReduceOp.SUM) + rewards /= self.accelerator.num_processes + + if self.config.log_with_wandb: + wandb_logs.update(stats) + wandb_logs['env/reward_mean'] = torch.mean(rewards).cpu().numpy() + wandb_logs['env/reward_std'] = torch.std(rewards).cpu().numpy() + wandb_logs['env/reward_dist'] = rewards.cpu().numpy() + wandb.log(wandb_logs) + else: + stats['env/reward_mean'] = torch.mean(rewards).cpu().numpy() + stats['env/reward_std'] = torch.std(rewards).cpu().numpy() + stats['env/reward_dist'] = rewards.cpu().numpy() + + else: + if self.is_distributed: + import torch.distributed as dist + if not isinstance(rewards, torch.Tensor): + rewards = torch.tensor(rewards).to(self.accelerator.device) + + dist.barrier() + dist.all_reduce(rewards, op=torch.distributed.ReduceOp.SUM) + diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py new file mode 100644 index 00000000000..4afa1eb8ff3 --- /dev/null +++ b/trl/trainer/utils.py @@ -0,0 +1,38 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + +class AdaptiveKLController: + """ + Adaptive KL controller described in the paper: + https://arxiv.org/pdf/1909.08593.pdf + """ + def __init__(self, init_kl_coef, target, horizon): + self.value = init_kl_coef + self.target = target + self.horizon = horizon + + def update(self, current, n_steps): + target = self.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.horizon + self.value *= mult + +class FixedKLController: + """Fixed KL controller.""" + def __init__(self, kl_coef): + self.value = kl_coef + + def update(self, current, n_steps): + pass