
Add T5 model #145

Merged on Jan 9, 2023 (34 commits)

Commits
ec00c54
add t5 to trlx
Dec 20, 2022
dacb652
add t5 examples for sentiment
Dec 20, 2022
56a0a3c
add eval for t5
Dec 20, 2022
5feff9f
fix eval
Dec 20, 2022
ccfabde
remove old files
Dec 21, 2022
2674d24
remove bad files
Dec 21, 2022
6e43ea1
remove bad files
Dec 21, 2022
59c2cf5
fix incompatible with gpt model, add summarization code base
Dec 21, 2022
2c133b0
freeze frozen branch
Dec 21, 2022
c9ddfcf
Merge branch 'main' into add_t5
PhungVanDuy Dec 21, 2022
5f38a81
fix evaluation bug t5, add summarization cnn/daily mail example
Dec 25, 2022
17be682
update sentiment example
Dec 27, 2022
2d1a4dc
stable config sentiment
Dec 27, 2022
f9f85ba
add attention mask decoder
Dec 29, 2022
500099f
setting worked - flant5 two unfrozen small rollouts
Dec 31, 2022
b55a4e8
merge newest code from main
Jan 1, 2023
36a74e6
fix head nn, config cnn daily mail, remove sent examples
Jan 2, 2023
6baee0b
fix style, change model_arch_type, truncated tokenizer fixed
Jan 6, 2023
d2082a7
fix style
Jan 6, 2023
d2f6a1d
precommit changes
Jan 6, 2023
eaf9c94
fix ppo state values for t5
Jan 7, 2023
c03313a
Merge branch 'main' into add_t5
PhungVanDuy Jan 7, 2023
93cf3cc
fix style
Jan 7, 2023
8ac399b
remove sentiment example
Jan 7, 2023
fefa62b
fix typo
Jan 7, 2023
5ae1188
fix ppo for causal models, add save best, seperate rollouts/eval args
Jan 7, 2023
ea10837
add ppo sentiment
Jan 7, 2023
84f8b7b
fix rewards typo
Jan 8, 2023
03cc954
Merge branch 'main' into add_t5
PhungVanDuy Jan 8, 2023
347e314
merging with main
Jan 8, 2023
220c8f3
fix style
Jan 8, 2023
a0a43f8
add docstring for gen_kwargs_inference, save best
Jan 9, 2023
b86e3d4
add gen kwargs support for rollouts sampling
Jan 9, 2023
eb0b0cc
Make summarization example self-contained
jon-tow Jan 9, 2023
8 changes: 5 additions & 3 deletions examples/ppo_sentiments.py
@@ -34,7 +34,7 @@ def main(hparams={}):
"lvwerra/distilbert-imdb",
top_k=2,
truncation=True,
batch_size=256,
batch_size=128,
device=device,
)

@@ -43,13 +43,15 @@ def reward_fn(samples: List[str]) -> List[float]:
return sentiments

# Take few words off of movies reviews as prompts
imdb = load_dataset("imdb", split="train+test")
imdb = load_dataset("imdb", split="train")
prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]
imdb = load_dataset("imdb", split="test")
val_prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]

return trlx.train(
reward_fn=reward_fn,
prompts=prompts,
eval_prompts=["I don't know much about Hungarian underground"] * 64,
eval_prompts=val_prompts[0:1000],
config=config,
)

Empty file.
53 changes: 53 additions & 0 deletions examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml
@@ -0,0 +1,53 @@
train:
  seq_length: 612
  epochs: 100
  total_steps: 100000
  batch_size: 12

  checkpoint_interval: 10000
  eval_interval: 500
  save_best: False

  pipeline: "PromptPipeline"
  orchestrator: "PPOOrchestrator"
  trainer: "AcceleratePPOTrainer"

model:
  model_path: "google/flan-t5-large"
  model_arch_type: "seq2seq"
  tokenizer_path: "google/flan-t5-large"
  num_layers_unfrozen: 2

optimizer:
  name: "adamw"
  kwargs:
    lr: 1.0e-5
    betas: [0.9, 0.999]
    eps: 1.0e-8
    weight_decay: 1.0e-6

scheduler:
  name: "cosine_annealing"
  kwargs:
    T_max: 10000
    eta_min: 1.0e-6

method:
  name: "ppoconfig"
  num_rollouts: 512
  chunk_size: 12
  ppo_epochs: 4
  init_kl_coef: 0.05
  target: 6
  horizon: 10000
  gamma: 0.99
  lam: 0.95
  cliprange: 0.2
  cliprange_value: 0.2
  vf_coef: 1.0
  scale_reward: False
  ref_mean: null
  ref_std: null
  cliprange_reward: 10
  gen_kwargs:
    max_new_tokens: 100
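
Note (not part of the diff): `seq_length` and `gen_kwargs.max_new_tokens` above jointly determine the prompt-token budget that the example script below relies on when truncating articles. A minimal sketch of that arithmetic, assuming the config has been loaded with `TRLConfig.load_yaml` (the file path is illustrative):

from trlx.data.configs import TRLConfig

# Path is illustrative; adjust to wherever this config lives.
config = TRLConfig.load_yaml("configs/ppo_config_cnn_daily.yml")

# Tokens reserved for the generated summary.
max_new_tokens = config.method.gen_kwargs["max_new_tokens"]  # 100

# Remaining budget for the tokenized prompt: 612 - 100 = 512 tokens.
max_prompt_length = config.train.seq_length - max_new_tokens
print(max_prompt_length)  # 512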
83 changes: 83 additions & 0 deletions examples/summarize_daily_cnn/t5_summarize_daily_cnn.py
@@ -0,0 +1,83 @@
import os
from typing import List

from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer

import trlx
from trlx.data.configs import TRLConfig

try:
    import evaluate
except ImportError:
    raise ImportError(
        "To run this example, please install the `evaluate` and `nltk` packages "
        "by running `pip install evaluate nltk`"
    )

config_path = os.path.join(
    os.path.dirname(__file__), "configs/ppo_config_cnn_daily.yml"
)
config = TRLConfig.load_yaml(config_path)

meteor = evaluate.load("meteor")  # use meteor as the reward function

if __name__ == "__main__":

    def reward_fn(samples: List[str]):
        sep_token = tokenizer.sep_token
        articles = [sample.split(sep_token)[0].strip() for sample in samples]
        predicted_summaries = [sample.split(sep_token)[1].strip() for sample in samples]
        labels = [prompt_label[sample] for sample in articles]
        scores = [
            meteor.compute(predictions=[summary], references=[label])
            for (summary, label) in zip(predicted_summaries, labels)
        ]
        scores = [score["meteor"] for score in scores]
        return scores

    dataset = load_dataset("cnn_dailymail", "3.0.0", cache_dir="data")

    # take 20,000 samples from the training set as prompts for training
    prompts = dataset["train"]["article"][0:20000]
    summaries = dataset["train"]["highlights"][0:20000]
    prompts = ["Summarize: " + prompt for prompt in prompts]

    # take 1,000 samples from the validation set as prompts for evaluation
    val_prompts = [
        "Summarize: " + prompt for prompt in dataset["validation"]["article"][0:1000]
    ]
    val_summaries = dataset["validation"]["highlights"][0:1000]

    # make dictionary of prompts and labels to use for reward function
    tokenizer = AutoTokenizer.from_pretrained(config.model.model_path)
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "right"
    tokenizer.sep_token = "<sep>"
    prompt_label = {}
    max_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"]

    for i in tqdm(range(len(prompts))):
        key = tokenizer.decode(
            tokenizer(prompts[i], truncation=True, max_length=max_length)["input_ids"],
            skip_special_tokens=True,
        )  # get prompt like trlx's prompt
        prompt_label[key.strip()] = summaries[i]

    for i in tqdm(range(len(val_prompts))):
        key = tokenizer.decode(
            tokenizer(val_prompts[i], truncation=True, max_length=max_length)[
                "input_ids"
            ],
            skip_special_tokens=True,
        )  # get prompt like trlx's prompt
        prompt_label[key.strip()] = val_summaries[i]

    model = trlx.train(
        config.model.model_path,
        reward_fn=reward_fn,
        prompts=prompts,
        eval_prompts=val_prompts,
        config=config,
    )
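
Note (not part of the diff): a minimal standalone check of the METEOR reward used by `reward_fn` above, assuming the `evaluate` and `nltk` packages are installed; the prediction/reference strings are illustrative only.

import evaluate

meteor = evaluate.load("meteor")

# `compute` returns a dict with a single "meteor" float; the reward function
# above extracts this value for every (predicted summary, reference) pair.
result = meteor.compute(
    predictions=["the cat sat on the mat"],
    references=["a cat was sitting on the mat"],
)
print(result["meteor"])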
10 changes: 9 additions & 1 deletion trlx/data/configs.py
@@ -29,6 +29,9 @@ class ModelConfig:
:param tokenizer_path: Path or name of the tokenizer (local or on huggingface hub)
:type tokenizer_path: str

:param model_arch_type: Type of model architecture. Either "causal" or "seq2seq"
:type model_arch_type: str

:param num_layers_unfrozen: Number of layers to unfreeze for fine-tuning.
-1 means all layers are unfrozen.
:type num_layers_unfrozen: int
@@ -47,8 +50,9 @@ class ModelConfig:

model_path: str
tokenizer_path: str
model_arch_type: str = "causal"
num_layers_unfrozen: int = -1
delta_kwargs: Dict[str, Any] = field(default_factory=dict)
delta_kwargs: Optional[Dict[str, Any]] = None

@classmethod
def from_dict(cls, config: Dict[str, Any]):
@@ -142,6 +146,9 @@ class TrainConfig:
Only used by AcceleratePPOTrainer.
:type rollout_logging_dir: Optional[str]

:param save_best: Save best model based on mean reward
:type save_best: bool

:param seed: Random seed
:type seed: int
"""
@@ -163,6 +170,7 @@ class TrainConfig:

checkpoint_dir: str = "ckpts"
rollout_logging_dir: Optional[str] = None
save_best: bool = True

trackers: Tuple[str] = ("wandb",)
seed: int = 1000
148 changes: 114 additions & 34 deletions trlx/orchestrator/ppo_orchestrator.py
@@ -40,6 +40,7 @@ def __init__(

if not hasattr(self.trainer.model, "frozen_head"):
self.ref_model = self.trainer.get_arch(self.trainer.config)
self.ref_model.to(self.trainer.accelerator.device)
Review comment (Collaborator): Have we verified this works? I recall Accelerate freezing up when I started putting multiple models on GPU (though that could have just been the sentiment pipeline we were using for the sentiments task).

Reply (PhungVanDuy, Collaborator Author, Jan 9, 2023): I agree that with larger models we should distribute them across multiple GPUs, but for this case I think we should keep the reference model on GPU rather than CPU; on CPU it is far too slow.

self.trainer.orch = self
self.trainer.reward_fn = reward_fn
@@ -75,11 +76,26 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
samples = self.trainer.generate(**batch)
stats["time/exp_generate"] = time() - exp_generate_time

query_tensors = batch.input_ids
response_tensors = samples[:, query_tensors.shape[1] :]
if self.trainer.config.model.model_arch_type == "seq2seq":
response_tensors = samples
else:
query_tensors = batch.input_ids
response_tensors = samples[:, query_tensors.shape[1] :]

texts = self.trainer.tokenizer.batch_decode(
samples, skip_special_tokens=True
)

if self.trainer.config.model.model_arch_type == "seq2seq":
articles = self.trainer.tokenizer.batch_decode(
batch.input_ids, skip_special_tokens=True
)
sep_token = self.trainer.tokenizer.sep_token
texts = [
f"{article}{sep_token}{response}"
for article, response in zip(articles, texts)
]

exp_score_time = time()
scores = torch.tensor(
self.score(texts), device=samples.device, dtype=torch.float
@@ -105,52 +121,117 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
scores = torch.clip(scores, -clip_reward, clip_reward)

# Precompute logprobs, values
all_tokens, attention_mask, position_ids = self.trainer.get_model_inputs(
query_tensors.to(response_tensors.device), response_tensors
)
Comment on lines -108 to -110
Review comment (Collaborator): This removes the need for `get_model_inputs`. We should remove that method, if it is now unused, before it becomes stale.

with torch.no_grad():
logits, *_, values = self.trainer.model(
all_tokens, attention_mask=attention_mask, position_ids=position_ids
if self.trainer.config.model.model_arch_type == "seq2seq":
response_tensors = response_tensors
attention_mask = batch.attention_mask.to(response_tensors.device)
query_tensors = batch.input_ids.to(response_tensors.device)
with torch.no_grad():
outputs = self.trainer.model(
input_ids=query_tensors,
attention_mask=attention_mask,
decoder_input_ids=response_tensors,
)
logits = outputs.logits
values = outputs.value
if hasattr(self.trainer.model, "frozen_head"):
ref_logits = self.trainer.model.forward_hydra(
input_ids=query_tensors,
attention_mask=attention_mask,
decoder_input_ids=response_tensors,
)
else:
ref_logits = self.ref_model(
input_ids=query_tensors,
attention_mask=attention_mask,
decoder_input_ids=response_tensors,
).logits
else:
all_tokens = torch.cat(
(query_tensors.to(response_tensors.device), response_tensors), dim=1
)
# TODO(dahoas): When hydra model works need to also support generation on hydra head
if hasattr(self.trainer.model, "frozen_head"):
ref_logits = self.trainer.model.forward_hydra(
attention_mask = (
all_tokens.not_equal(self.trainer.tokenizer.pad_token_id)
.long()
.to(all_tokens.device)
)
with torch.no_grad():
logits, *_, values = self.trainer.model(
all_tokens,
attention_mask=attention_mask,
position_ids=position_ids,
return_dict=False,
)
else:
ref_logits, _, *_ = self.ref_model(
all_tokens.cpu(),
attention_mask=attention_mask.cpu(),
position_ids=position_ids.cpu(),
)
ref_logits = ref_logits.to(self.trainer.accelerator.device)
# TODO(dahoas): When hydra model works need to also support generation on hydra head
if hasattr(self.trainer.model, "frozen_head"):
ref_logits = self.trainer.model.forward_hydra(
all_tokens,
attention_mask=attention_mask,
return_dict=False,
)
else:
ref_logits, _, *_ = self.ref_model(
all_tokens,
attention_mask=attention_mask,
return_dict=False,
)
ref_logits = ref_logits.to(self.trainer.accelerator.device)

logprobs = logprobs_from_logits(logits[:, :-1, :], all_tokens[:, 1:])
ref_logprobs = logprobs_from_logits(
ref_logits[:, :-1, :], all_tokens[:, 1:]
)
if self.trainer.config.model.model_arch_type == "seq2seq":
logprobs = logprobs_from_logits(
logits[:, :-1, :], response_tensors[:, 1:]
)
ref_logprobs = logprobs_from_logits(
ref_logits[:, :-1, :], response_tensors[:, 1:]
)
else:
logprobs = logprobs_from_logits(logits[:, :-1, :], all_tokens[:, 1:])
ref_logprobs = logprobs_from_logits(
ref_logits[:, :-1, :], all_tokens[:, 1:]
)

n = samples.shape[0]
values = values.cpu()[:, :-1]
logprobs = logprobs.cpu()
ref_logprobs = ref_logprobs.cpu()
query_tensors = query_tensors.cpu()
response_tensors = response_tensors.cpu()

start = query_tensors.shape[1] - 1
ends = start + attention_mask[:, start:].sum(1)
all_values = [values[ix, start : ends[ix]] for ix in range(n)]
all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n)]
if self.trainer.config.model.model_arch_type == "seq2seq":
start = 1 # skip the <s> token
ends = (response_tensors[:, start:] != 0).sum(1)
all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n)]
all_values = [values[ix, start - 1 : ends[ix] - 1] for ix in range(n)]
rewards = [
-self.trainer.kl_ctl.value
* (
logprobs[ix, start : ends[ix]]
- ref_logprobs[ix, start : ends[ix]]
)
for ix in range(n)
]
else:
n = samples.shape[0]
values = values.cpu()
logprobs = logprobs.cpu()
ref_logprobs = ref_logprobs.cpu()
query_tensors = query_tensors.cpu()
response_tensors = response_tensors.cpu()
Comment on lines +211 to +214
Review comment (Collaborator): Remove these lines; these variables are already moved to CPU on the lines right before the if-statement.

start = (
query_tensors.shape[1] - 1
) # left shift by 1 ref: https://github.com/lvwerra/trl/blob/main/trl/trainer/ppo_trainer.py#L425
ends = start + attention_mask[:, start:].sum(1) - 1
for ix in range(n):
if ends[ix] == all_tokens.shape[1]:
ends[ix] = ends[ix] - 1
all_values = [values[ix, start - 1 : ends[ix] - 1] for ix in range(n)]
all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n)]
rewards = -self.trainer.kl_ctl.value * (logprobs - ref_logprobs)
rewards = [rs[start : ends[ix]] for ix, rs in enumerate(rewards)]

# Compute rewards
rewards = -self.trainer.kl_ctl.value * (logprobs - ref_logprobs)
all_rewards = [None] * n

for ix in range(n):
rs = rewards[ix][start : ends[ix]]
rs[-1] = scores[ix]
rs = rewards[ix]
if len(rs) == 0:
rs = torch.tensor([0.0])
Review comment (maxreciprocate, Collaborator, Jan 9, 2023): Should we penalize empty responses? Also, do you know how it is possible to get them, except when max_new_tokens == 0? 🤔

Reply (PhungVanDuy, Collaborator Author, Jan 9, 2023): Yes, this is a rare exception case, but I hit it a few times when running PPO sentiments. @jon-tow, you ran into this as well, right?

Reply (Collaborator): Yeah, Jon has said that he also experienced it. If the cause is unknown, I wonder whether it may be a symptom of some other bug elsewhere.

Reply (jon-tow, Collaborator, Jan 9, 2023): @reciprocated the only case I can think of that would lead to an empty response is when len(query) is larger than the generate method's min_length arg, which defaults to 10, and the model happens to output the eos_token on its first sample. (Note that with causal models the min_length constraint includes the length of the context (query), meaning it has no actual effect on the generations if the minimum is already met by the context size.) In such cases, I'm okay with penalizing empty responses as they are uninformative, so long as this is not a bug.

Reply (Contributor): I do not think it is a serious issue. We can just throw an error if this min_length case comes up; I've never seen it in practice when I set min_length correctly. (Perhaps we should add an extra parameter called min_new_length...? We should upstream that to HF transformers, though.)

(A short sketch of this min_length behaviour follows the end of this diff.)

rs[-1] += scores[ix].cpu()
all_rewards[ix] = rs

new_ppo_rl_elements = [
@@ -163,7 +244,6 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
)
for i in range(n)
]

ppo_rl_elements += new_ppo_rl_elements
exp_time = clock.tick()
