accelerate integration
#58
Conversation
    #### Run PPO step
    t = time.time()
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, timing, batch, rewards, t0, t, logs)
To improve this, we probably want a better way to log the stats.
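One possible direction, sketched as a hypothetical method: let the trainer track timing itself so callers only pass the stats, the batch and the rewards. The signature, the reward keys and the `accelerator.log` tracker call are assumptions for illustration, not the PR's actual API:

```python
import torch

def log_stats(self, stats, batch, rewards):
    """Hypothetical simplified logging: the trainer owns timing and the logger,
    so the caller no longer threads timing/t0/t/logs through every call."""
    if not self.accelerator.is_main_process:
        return  # only the main process should write to the tracker
    logs = dict(stats)
    rewards = torch.stack([r.float() for r in rewards])
    logs["env/reward_mean"] = rewards.mean().item()
    logs["env/reward_std"] = rewards.std().item()
    # assumes the Accelerator was created with a tracker, e.g. log_with="wandb"
    self.accelerator.log(logs)
```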
trl/trainer/accelerate_ppo.py
Outdated
    if isinstance(v, torch.Tensor) and k != 'objective/kl':
        # tensor_list = [torch.zeros_like(v) for _ in range(self.accelerator.num_processes)]
        dist.all_reduce(v, dist.ReduceOp.SUM)
        v /= self.accelerator.num_processes
For me, in a DP setup each GPU will need to have its own replica of `objective/kl` since this is used to update the `kl_ctl` object above. That is why I preferred not to include it in the all_reduce operation, but I just wanted to confirm.
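For reference, a minimal sketch of that policy, assuming `stats` is a flat dict of tensors and that `torch.distributed` is already initialised (the helper name is illustrative):

```python
import torch
import torch.distributed as dist

def average_stats_across_processes(stats, num_processes):
    """Average tensor stats over all replicas, but keep 'objective/kl' local
    so each process can update its own kl_ctl controller."""
    for k, v in stats.items():
        if isinstance(v, torch.Tensor) and k != "objective/kl":
            dist.all_reduce(v, op=dist.ReduceOp.SUM)  # in-place sum across replicas
            v /= num_processes                        # turn the sum into a mean
    return stats
```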
- add docstring on most functions
- correct logging
lvwerra
left a comment
Thanks for this @younesbelkada. My main comments are about DP. I think if we don't wrap the step inputs (queries/responses) in a dataloader we don't achieve proper DP. But maybe I am wrong?
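A minimal sketch of what that could look like, assuming the queries live in a `datasets`-style dataset and the trainer calls `accelerator.prepare` on the dataloader (names and the toy data are illustrative):

```python
from accelerate import Accelerator
from torch.utils.data import DataLoader

accelerator = Accelerator()

# Toy stand-in for the real query dataset (in practice this comes from `datasets`).
dataset = [{"query": f"sample {i}", "input_ids": [0, 1, 2]} for i in range(32)]

def collate_fn(examples):
    # Keep raw query strings and pre-tokenized input_ids together per batch.
    return {
        "query": [e["query"] for e in examples],
        "input_ids": [e["input_ids"] for e in examples],
    }

dataloader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)

# prepare() shards the dataloader across processes, so each GPU only receives
# its slice of the queries; without this, every process repeats the same batch
# and there is no real data parallelism for the PPO step inputs.
dataloader = accelerator.prepare(dataloader)
```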
trl/trainer/accelerate_ppo.py
Outdated
    model (torch.model): Hugging Face transformer GPT2 model with value head
    ref_model (torch.model): Hugging Face transformer GPT2 reference model used for KL penalty
    tokenizer (tokenizer): Hugging Face tokenizer
    ppo_params (dict or None): PPO parameters for training. Can include following keys:
We should replace `**config` (= `ppo_params`) with explicit kwargs or set up `TrainingArguments` like in `transformers`.
Can be a follow-up PR btw.
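One way an explicit config could look is a small dataclass in the spirit of `TrainingArguments`; the field names and defaults below are illustrative, not necessarily the ones the PR ends up with:

```python
from dataclasses import dataclass

@dataclass
class PPOConfig:
    """Explicit, typed PPO hyperparameters instead of an opaque **ppo_params dict."""
    learning_rate: float = 1.41e-5
    batch_size: int = 256
    forward_batch_size: int = 16
    ppo_epochs: int = 4
    init_kl_coef: float = 0.2
    target: float = 6.0
    gamma: float = 1.0
    lam: float = 0.95
    cliprange: float = 0.2
    cliprange_value: float = 0.2
    vf_coef: float = 0.1
    seed: int = 0
```

A config object like this can then be passed to the trainer in place of the `**ppo_params` dict, and unknown keys become type errors instead of silent no-ops.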
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
- random init seems to converge much faster
wandb run (multi-GPU) after the latest commit: https://wandb.ai/distill-bloom/trl/runs/1mps4h09?workspace=user-younesbelkada
lvwerra
left a comment
I think we are pretty close - a few open questions and minor changes :)
    stats (dict[str, Any]):
        a dictionary of stats with the tensors gathered.
    """
    import torch.distributed as dist
what do you think?
    # In a distributed setup, only logging needs to be performed on the main process
    # check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
    # or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11
    self.is_distributed = self.accelerator.distributed_type == "MULTI_GPU"
If we can use accelerate's gather method we can probably get rid of this?
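A sketch of what that could look like with `Accelerator.gather`, which simply returns the tensor on a single process and concatenates values across processes otherwise, so the explicit `is_distributed` flag and the raw `torch.distributed` calls would no longer be needed (the helper name and the flatten-then-mean reduction are assumptions):

```python
import torch

def gather_stats(accelerator, stats):
    """Collect tensor stats from every process and average them on each rank."""
    gathered = {}
    for k, v in stats.items():
        if isinstance(v, torch.Tensor):
            # gather() returns the values from all processes concatenated on dim 0;
            # on a single process it simply returns the tensor unchanged.
            values = accelerator.gather(v.flatten().to(accelerator.device))
            gathered[k] = values.float().mean()
        else:
            gathered[k] = v
    return gathered
```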
    #### Compute sentiment score
    t = time.time()
    texts = [q + r for q, r in zip(batch['query'], batch['response'])]
With the remove-columns method inside the trainer, the query shouldn't be there anymore? Since we don't pass the data through the model internally, we don't need to remove the columns?
The queries are kept here: https://github.com/younesbelkada/trl/blob/d2c363fe4018c74df829ed6c067fad50ecaaf479/trl/trainer/ppo_trainer.py#L152, but maybe we can change that, wdyt?
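For context, a rough sketch of why the `query` column is handy downstream: the reward in the sentiment example is computed on prompt + continuation, so both columns are read back from the batch. The toy batch and pipeline settings below are illustrative:

```python
import torch
from transformers import pipeline

# Toy batch; in the real script this comes out of the PPO dataloader,
# which is why keeping the 'query' column around is convenient.
batch = {"query": ["This movie was"], "response": [" surprisingly good."]}

# Sentiment classifier used as the reward model in the sentiment example.
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb")

# Reward is computed on prompt + continuation, so both columns are needed here.
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
pipe_outputs = sentiment_pipe(texts, return_all_scores=True)

# Use the POSITIVE-class score of each text as the scalar reward.
rewards = [
    torch.tensor(next(s["score"] for s in out if s["label"] == "POSITIVE"))
    for out in pipe_outputs
]
```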
Wandb log of the final run: https://wandb.ai/distill-bloom/trl/runs/dcd2gqn1?workspace=user-younesbelkada
The documentation is not available anymore as the PR was closed or merged.
* working v1
* add `accelerate` on requirements
* add `accelerate` on `setup.py`
* add `datasets` on `setup.py`
* small updates
  - add docstring on most functions
  - correct logging
* rm unneeded file
* replace with `generate`
* Update trl/trainer/accelerate_ppo.py
  Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* correct return
* add dataloader support
* add `wandb` to `setup.py`
* refactor
  - remove `_build_dataset` method
  - change name to `PPOTrainer`
* test
* fix test
* rename file
* refactor
* remove unneeded device assignment
* fix correct device assignment
* standardize docstrings
* add `wandb` on `dev`
* fix slow convergence
  - random init seems to converge much faster
* oops
* revert fix
* revert patch
* remove unneeded reshape
* add input safety checker
* refactor
  - added comments on example
  - fixes CI test
  - rewards should be a list of tensors
  - clearer error messages
  - remove build model method
  - refactor log stats method
  Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* Apply suggestions from code review
  Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* refactor
  - added `PPOConfig` class
  - docstring on `LengthSampler`
  - fix test
  - gather rewards when logging
  - unwrap model when calling generate
* some refactor
* remove unneeded hack
* adapt dataset
* fix test
* remove rollout
* remove timing
* remove `shuffle=True`
* remove `LengthSampler` from trainer
* refactor
* remove text length sampler args from config
* change collate_fn
* fix silent bug
* rename
* move file
* refactor base trainer
* fix collate
* final bug

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
What does this PR do?
This PR integrates `trl` with `accelerate` to make it compatible with the tools provided by the library and to be able to train models using `PPOTrainer`. This would enable users to train their models in mixed precision, using Data Parallelism, etc., in a very simple manner.

Users should design their own training scripts and run them using `accelerate launch xxx.py`, based on the example scripts provided in `examples/scripts`.

This PR also integrates the Data Parallelism paradigm, enabling users to benefit from multi-GPU training if they want to speed up training.
TODOs
- ... `accelerate` examples)
- DeepSpeed tests (check where it works)