-
Notifications
You must be signed in to change notification settings - Fork 2.8k
accelerate integration
#58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
9c977d0
1971cea
45cad09
a0ebdaa
dec21f3
4254292
19f4d92
35330a9
34773de
b810d8a
e4c57b2
7516b37
40f81e0
b1638e5
e2e7a90
96b4115
5eb46ad
609f718
4d57b47
fac85b5
c1b166b
9495f2a
c813857
157eca6
2efb961
0a1c9a2
b6004f0
f47b907
2918a8e
747d5f0
7615994
65be5bd
edd5ea3
76c2afd
5d41170
7843a34
6cd89d5
4e802e8
6012a9b
d2c363f
d048bbe
66f23b1
e318307
31d12d6
48c1070
e9cec71
9a987d4
244f001
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| import torch | ||
| import time | ||
| from tqdm import tqdm | ||
| import numpy as np | ||
| tqdm.pandas() | ||
|
|
||
| from transformers import pipeline | ||
|
|
||
| from trl import AcceleratePPOTrainer | ||
|
|
||
| config = { | ||
| "model_name": "lvwerra/gpt2-imdb", | ||
| # "model_name": "facebook/opt-350m", | ||
| "cls_model_name": "lvwerra/distilbert-imdb", | ||
| "steps": 20000, | ||
| "batch_size": 128, | ||
| "forward_batch_size": 16, | ||
| "ppo_epochs": 4, | ||
| "txt_in_min_len": 2, | ||
| "txt_in_max_len": 8, | ||
| "txt_out_min_len": 4, | ||
| "txt_out_max_len": 16, | ||
| "lr": 1.41e-5, | ||
| "init_kl_coef":0.2, | ||
| "target": 6, | ||
| "horizon":10000, | ||
| "gamma":1, | ||
| "lam":0.95, | ||
| "cliprange": .2, | ||
| "cliprange_value":.2, | ||
| "vf_coef":.1, | ||
| } | ||
|
|
||
| sent_kwargs = { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here re data classes |
||
| "return_all_scores": True, | ||
| "function_to_apply": "none", | ||
| "batch_size": config["forward_batch_size"] | ||
| } | ||
|
|
||
| ppo_trainer = AcceleratePPOTrainer(**config) | ||
| tokenizer = ppo_trainer.tokenizer | ||
|
|
||
| device = ppo_trainer.accelerator.device | ||
| if device.index is None: | ||
| # single GPU - maybe introduce this hack inside AcceleratePPOTrainer? | ||
| device = 0 | ||
| sentiment_pipe = pipeline("sentiment-analysis","lvwerra/distilbert-imdb", device=device) | ||
|
|
||
|
|
||
| gen_kwargs = { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Might be cleaner to use the new
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed!
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, let's not commit to very new features, otherwise we need very hard |
||
| "min_length":-1, | ||
| "top_k": 0.0, | ||
| "top_p": 1.0, | ||
| "do_sample": True, | ||
| "pad_token_id": tokenizer.eos_token_id | ||
| } | ||
|
|
||
| total_ppo_epochs = int(np.ceil(config["steps"]/config['batch_size'])) | ||
|
|
||
| for epoch, batch in tqdm(zip(range(total_ppo_epochs), iter(ppo_trainer.dataloader))): | ||
| logs, timing = dict(), dict() | ||
| t0 = time.time() | ||
| query_tensors = [torch.tensor(t).long().to(device) for t in batch["tokens"]] | ||
|
|
||
| #### Get response from gpt2 | ||
| t = time.time() | ||
| response_tensors = ppo_trainer.get_response(query_tensors, **gen_kwargs) | ||
| batch['response'] = [tokenizer.decode(r.squeeze()) for r in response_tensors] | ||
| timing['time/get_response'] = time.time()-t | ||
|
|
||
| #### Compute sentiment score | ||
| t = time.time() | ||
| texts = [q + r for q,r in zip(batch['query'], batch['response'])] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. with the remove columns method inside the trainer the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The query are kept here, https://github.com/younesbelkada/trl/blob/d2c363fe4018c74df829ed6c067fad50ecaaf479/trl/trainer/ppo_trainer.py#L152 but maybe we can change that, wdyt? |
||
| pipe_outputs = sentiment_pipe(texts, **sent_kwargs) | ||
| rewards = torch.tensor([output[1]["score"] for output in pipe_outputs]).to(device) | ||
| timing['time/get_sentiment_preds'] = time.time()-t | ||
|
|
||
| #### Run PPO step | ||
| t = time.time() | ||
| stats = ppo_trainer.step(query_tensors, response_tensors, rewards) | ||
| ppo_trainer.log_stats(stats, timing, batch, rewards, t0, t, logs) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To improve, we probably want a better way to log the stats |
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,5 +5,6 @@ datasets==1.17.0 | |
| torch>=1.4.0 | ||
| tqdm | ||
| transformers | ||
| accelerate | ||
| wandb==0.10.20 | ||
| matplotlib==3.5.1 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| __version__ = "0.1.1" | ||
|
|
||
| from .models import AutoModelForCausalLMWithValueHead | ||
| from .models import AutoModelForCausalLMWithValueHead | ||
| from .trainer import AcceleratePPOTrainer |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| from .base import BaseTrainer | ||
| from .utils import AdaptiveKLController, FixedKLController, LengthSampler | ||
| from .accelerate_ppo import AcceleratePPOTrainer |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggestion: use data classes instead of dicts for the config (easier to refactor in future) like we do in
transformers: https://github.com/huggingface/transformers/blob/bbcd961897aa6cc439ef4cca5cef6db4283c5b76/examples/pytorch/text-classification/run_glue.py#L70There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a simple dataclass for now: 747d5f0
maybe we can refactor as it is done in
transformersas a follow up PR!