Merged
48 commits
9c977d0
working v1
younesbelkada Dec 27, 2022
1971cea
add `accelerate` on requirements
younesbelkada Dec 27, 2022
45cad09
add `accelerate` on `setup.py`
younesbelkada Dec 27, 2022
a0ebdaa
add `datasets` on `setup.py`
younesbelkada Dec 27, 2022
dec21f3
small updates
younesbelkada Dec 27, 2022
4254292
rm unneeded file
younesbelkada Dec 27, 2022
19f4d92
replace with `generate`
younesbelkada Dec 27, 2022
35330a9
Update trl/trainer/accelerate_ppo.py
younesbelkada Dec 27, 2022
34773de
correct return
younesbelkada Dec 27, 2022
b810d8a
add dataloader support
younesbelkada Dec 27, 2022
e4c57b2
add `wandb` to `setup.py`
younesbelkada Dec 27, 2022
7516b37
refactor
younesbelkada Dec 27, 2022
40f81e0
test
younesbelkada Dec 27, 2022
b1638e5
fix test
younesbelkada Dec 27, 2022
e2e7a90
rename file
younesbelkada Dec 27, 2022
96b4115
refactor
younesbelkada Dec 27, 2022
5eb46ad
remove unneeded device assignment
younesbelkada Dec 27, 2022
609f718
fix correct device assignment
younesbelkada Dec 27, 2022
4d57b47
standardize docstrings
younesbelkada Dec 27, 2022
fac85b5
add `wandb` on `dev`
younesbelkada Dec 27, 2022
c1b166b
fix slow convergence
younesbelkada Dec 28, 2022
9495f2a
oops
younesbelkada Dec 28, 2022
c813857
revert fix
younesbelkada Dec 28, 2022
157eca6
revert patch
younesbelkada Dec 28, 2022
2efb961
Merge remote-tracking branch 'origin/master' into accelerate-ppo
younesbelkada Dec 28, 2022
0a1c9a2
remove unneeded reshape
younesbelkada Dec 28, 2022
b6004f0
add input safety checker
younesbelkada Dec 28, 2022
f47b907
refactor
younesbelkada Dec 28, 2022
2918a8e
Apply suggestions from code review
younesbelkada Dec 29, 2022
747d5f0
refactor
younesbelkada Dec 29, 2022
7615994
some refactor
younesbelkada Dec 29, 2022
65be5bd
remove unneeded hack
younesbelkada Dec 29, 2022
edd5ea3
adapt dataset
younesbelkada Dec 29, 2022
76c2afd
fix test
younesbelkada Dec 29, 2022
5d41170
remove rollout
younesbelkada Dec 29, 2022
7843a34
remove timing
younesbelkada Dec 29, 2022
6cd89d5
remove `shuffle=True`
younesbelkada Dec 29, 2022
4e802e8
remove `LengthSampler` from trainer
younesbelkada Dec 29, 2022
6012a9b
refactor
younesbelkada Dec 29, 2022
d2c363f
remove text length sampler args from config
younesbelkada Dec 29, 2022
d048bbe
change collate_fn
younesbelkada Dec 29, 2022
66f23b1
fix silent bug
younesbelkada Dec 29, 2022
e318307
rename
younesbelkada Dec 29, 2022
31d12d6
move file
younesbelkada Dec 29, 2022
48c1070
refactor base trainer
younesbelkada Dec 29, 2022
e9cec71
fix collate
younesbelkada Dec 29, 2022
9a987d4
Merge remote-tracking branch 'origin/master' into accelerate-ppo
younesbelkada Dec 29, 2022
244f001
final bug
younesbelkada Dec 29, 2022
82 changes: 82 additions & 0 deletions examples/scripts/04-ppo-sentiment-accelerate.py
@@ -0,0 +1,82 @@
import torch
import time
from tqdm import tqdm
import numpy as np
tqdm.pandas()

from transformers import pipeline

from trl import AcceleratePPOTrainer

config = {

Member

Suggestion: use dataclasses instead of dicts for the config (easier to refactor in the future), as we do in transformers: https://github.com/huggingface/transformers/blob/bbcd961897aa6cc439ef4cca5cef6db4283c5b76/examples/pytorch/text-classification/run_glue.py#L70

Contributor Author

Added a simple dataclass for now: 747d5f0
Maybe we can refactor it the way transformers does in a follow-up PR!
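
For illustration, a minimal sketch of what a dataclass-based config for this script could look like. The class name `PPOScriptConfig` and the divisibility check are assumptions made for this example, not the code added in 747d5f0; the field names and defaults mirror the dict in this diff (the text-length fields are omitted for brevity).

```python
from dataclasses import asdict, dataclass


@dataclass
class PPOScriptConfig:  # hypothetical name, chosen for this sketch
    model_name: str = "lvwerra/gpt2-imdb"
    cls_model_name: str = "lvwerra/distilbert-imdb"
    steps: int = 20000
    batch_size: int = 128
    forward_batch_size: int = 16
    ppo_epochs: int = 4
    lr: float = 1.41e-5
    init_kl_coef: float = 0.2
    target: float = 6.0
    horizon: int = 10000
    gamma: float = 1.0
    lam: float = 0.95
    cliprange: float = 0.2
    cliprange_value: float = 0.2
    vf_coef: float = 0.1

    def __post_init__(self):
        # Assumed sanity check, added only to show where validation could live.
        if self.batch_size % self.forward_batch_size != 0:
            raise ValueError("batch_size must be a multiple of forward_batch_size")


config = PPOScriptConfig(batch_size=64)
config_dict = asdict(config)  # plain dict, e.g. for logging or **-unpacking
```

Compared with a dict, a misspelled field name fails at construction time, and defaults live in one typed place.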

"model_name": "lvwerra/gpt2-imdb",
# "model_name": "facebook/opt-350m",
"cls_model_name": "lvwerra/distilbert-imdb",
"steps": 20000,
"batch_size": 128,
"forward_batch_size": 16,
"ppo_epochs": 4,
"txt_in_min_len": 2,
"txt_in_max_len": 8,
"txt_out_min_len": 4,
"txt_out_max_len": 16,
"lr": 1.41e-5,
"init_kl_coef":0.2,
"target": 6,
"horizon":10000,
"gamma":1,
"lam":0.95,
"cliprange": .2,
"cliprange_value":.2,
"vf_coef":.1,
}

sent_kwargs = {

Member

Same here regarding dataclasses.

"return_all_scores": True,
"function_to_apply": "none",
"batch_size": config["forward_batch_size"]
}

ppo_trainer = AcceleratePPOTrainer(**config)
tokenizer = ppo_trainer.tokenizer

device = ppo_trainer.accelerator.device
if device.index is None:
    # single GPU - maybe introduce this hack inside AcceleratePPOTrainer?
    device = 0
sentiment_pipe = pipeline("sentiment-analysis","lvwerra/distilbert-imdb", device=device)
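
As a sketch of how the device hack above could be folded into a helper: the function name `pipeline_device_index` is hypothetical, and the CPU branch is an assumption that goes beyond the script, which maps every index-less device to GPU 0. The `pipeline` API accepts an integer device, with -1 meaning CPU.

```python
from types import SimpleNamespace


def pipeline_device_index(device):
    """Map a torch.device-like object to the integer `device` argument of `pipeline`."""
    if device.type != "cuda":
        return -1  # the pipeline API uses -1 for CPU
    # A bare "cuda" device carries no index; treat it as the first GPU.
    return 0 if device.index is None else device.index


# Stand-ins for torch.device objects so the sketch runs without torch installed.
cpu = SimpleNamespace(type="cpu", index=None)
single_gpu = SimpleNamespace(type="cuda", index=None)
second_gpu = SimpleNamespace(type="cuda", index=1)
indices = [pipeline_device_index(d) for d in (cpu, single_gpu, second_gpu)]
```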


gen_kwargs = {

Member

@younesbelkada younesbelkada Dec 29, 2022

Contributor Author

Agreed!
However, I tried to run generate with GenerationConfig and didn't manage to make it work, since the feature currently seems to be available only on the main branch: huggingface/transformers#20388 was merged 2 weeks ago.
So maybe let's address this in a follow-up PR!

Member

Yes, let's not commit to very new features; otherwise we would need a very strict transformers version requirement.
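
One way to avoid a strict version requirement, sketched under assumptions: wrap the sampling kwargs in a `GenerationConfig` only when the installed transformers exposes that class, and fall back to plain kwargs otherwise. The helper name `build_generate_args` is hypothetical, and the kwargs shown are illustrative values rather than the script's exact `gen_kwargs`.

```python
def build_generate_args(gen_kwargs):
    """Wrap sampling kwargs in a GenerationConfig when the installed transformers has one."""
    try:
        from transformers import GenerationConfig
    except ImportError:
        # transformers is missing or predates GenerationConfig: pass plain kwargs.
        return dict(gen_kwargs)
    return {"generation_config": GenerationConfig(**gen_kwargs)}


sampling = {"do_sample": True, "top_p": 1.0, "max_new_tokens": 16}
generate_args = build_generate_args(sampling)
# Either form can then be unpacked into a call such as
# model.generate(input_ids, **generate_args).
```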

"min_length":-1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id
}

total_ppo_epochs = int(np.ceil(config["steps"]/config['batch_size']))

for epoch, batch in tqdm(zip(range(total_ppo_epochs), iter(ppo_trainer.dataloader))):
    logs, timing = dict(), dict()
    t0 = time.time()
    query_tensors = [torch.tensor(t).long().to(device) for t in batch["tokens"]]

    #### Get response from gpt2
    t = time.time()
    response_tensors = ppo_trainer.get_response(query_tensors, **gen_kwargs)
    batch['response'] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
    timing['time/get_response'] = time.time()-t

    #### Compute sentiment score
    t = time.time()
    texts = [q + r for q,r in zip(batch['query'], batch['response'])]

Member

With the remove-columns method inside the trainer, the query shouldn't be there anymore? Or, since we don't pass the data through the model internally, do we not need to remove the columns?
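
To make the question concrete, here is a minimal sketch of a collate function that keeps every column, including raw strings such as `query`, by turning a list of samples into a dict of lists. The name `collator` and the sample data are assumptions for this example; it is not necessarily the `collate_fn` used in this PR.

```python
def collator(samples):
    """Turn a list of per-sample dicts into one dict of per-column lists."""
    return {key: [sample[key] for sample in samples] for key in samples[0]}


samples = [
    {"query": "This movie", "tokens": [1212, 3807]},
    {"query": "I really", "tokens": [40, 1107]},
]
batch = collator(samples)
```

Because nothing is stacked into tensors, variable-length token lists and string columns survive untouched, which is what the loop needs for `batch['query']`.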

Contributor Author

    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = torch.tensor([output[1]["score"] for output in pipe_outputs]).to(device)
    timing['time/get_sentiment_preds'] = time.time()-t

    #### Run PPO step
    t = time.time()
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, timing, batch, rewards, t0, t, logs)

@younesbelkada younesbelkada Dec 27, 2022

Contributor Author

To improve, we probably want a better way to log the stats
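
One possible direction, offered only as a sketch: collect timings with a small context manager so the loop no longer has to thread `t0` and `t` into `log_stats`. The helper name `timed` is hypothetical and not part of this PR.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(timing, key):
    """Store the wall-clock duration of the enclosed block under timing[key]."""
    start = time.time()
    try:
        yield
    finally:
        timing[key] = time.time() - start


timing = {}
with timed(timing, "time/get_response"):
    total = sum(range(1000))  # stand-in for the generation call
```

The trainer could then receive a single `timing` dict alongside `stats`, instead of raw timestamps.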


1 change: 1 addition & 0 deletions requirements.txt
@@ -5,5 +5,6 @@ datasets==1.17.0
torch>=1.4.0
tqdm
transformers
accelerate
wandb==0.10.20
matplotlib==3.5.1
3 changes: 2 additions & 1 deletion trl/__init__.py
@@ -1,3 +1,4 @@
__version__ = "0.1.1"

from .models import AutoModelForCausalLMWithValueHead
from .models import AutoModelForCausalLMWithValueHead
from .trainer import AcceleratePPOTrainer
3 changes: 3 additions & 0 deletions trl/trainer/__init__.py
@@ -0,0 +1,3 @@
from .base import BaseTrainer
from .utils import AdaptiveKLController, FixedKLController, LengthSampler
from .accelerate_ppo import AcceleratePPOTrainer