
Fix distributed dataloaders & deduplicate eval #276

Merged: 2 commits into main from fix-ppo-dataloaders on Feb 4, 2023
Conversation

@maxreciprocate (Collaborator) commented on Feb 4, 2023

  • Shards eval_pipeline to split evaluation samples among processes (accelerate.prepare(dataloader) is a no-op with deepspeed)
  • Trims redundant evaluation samples when their number is not divisible by the number of processes multiplied by the batch size (depends on recent fixes in the newest accelerate release, 0.16.0)
  • Shards the prompt pipeline in PPOOrchestrator in the case of deepspeed (previously it was sharded only for ddp)
  • Unshards the training pipeline in the case of ddp (previously only part of the rollouts was used on each process during training)

https://wandb.ai/sorry/trlx/reports/Fix-distributed-dataloaders-deduplicate-eval-276---VmlldzozNDgzNDU3

These changes affect training with both deepspeed & ddp only insofar as they correct the effective batch_size and num_rollouts. The number of evaluation samples now equals the number of passed eval_prompts (this fixes #247)

ppo_sentiments.py @ num_processes=4, batch_size=16, num_rollouts=128, len(eval_prompts)=1024

| setup            | len(train_dataloader) | len(eval_dataloader) |
|------------------|-----------------------|----------------------|
| main & ddp       | 2                     | 16                   |
| fix & ddp        | 8                     | 16                   |
| main & deepspeed | 8                     | 64                   |
| fix & deepspeed  | 8                     | 16                   |
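
As a sanity check, the lengths above fall out of simple arithmetic (assuming rollouts are collected per process, which is what the "unshards the training pipeline" item above implies):

# Worked check of the per-process dataloader lengths reported above
num_processes, batch_size, num_rollouts, num_eval_prompts = 4, 16, 128, 1024

# Training: rollouts are already collected per process, so they should not be split again
assert num_rollouts // batch_size == 8                           # fixed len(train_dataloader)

# Evaluation: eval_prompts are shared, so they should be split across processes
assert num_eval_prompts // (num_processes * batch_size) == 16    # fixed len(eval_dataloader)

# The buggy cases follow from the same arithmetic:
assert num_rollouts // (num_processes * batch_size) == 2         # main & ddp, len(train_dataloader)
assert num_eval_prompts // batch_size == 64                      # main & deepspeed, len(eval_dataloader)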

@jon-tow (Collaborator) left a comment

Awesome! Just left a tiny request but feel free to leave it for another PR.

Aside (while on the topic of dataloaders): do you know what the rollout_loader in the lines below is for? It seems it was used to satisfy prepare with deepspeed before that became a no-op

rollout_loader = self.store.create_loader(self.config.train.batch_size, shuffle=True)
self.model, self.opt, self.scheduler, rollout_loader = self.accelerator.prepare(
    self.model, self.opt, self.scheduler, rollout_loader
)

Comment on lines +356 to +358
samples = self.accelerator.gather_for_metrics(torch.vstack(all_samples))
prompts = self.accelerator.gather_for_metrics(torch.vstack(all_prompts))
prompt_sizes = self.accelerator.gather_for_metrics(torch.hstack(prompt_sizes))
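
For context, gather_for_metrics is what drops the duplicated samples here; a minimal self-contained sketch of that behaviour (assuming accelerate>=0.16.0 for the non-divisible case; the toy loader and shapes below are illustrative, not trlx code):

import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator

accelerator = Accelerator()

# 10 "eval prompts", deliberately not divisible by num_processes * batch_size
data = torch.arange(10).unsqueeze(1).float()
loader = accelerator.prepare(DataLoader(data, batch_size=2))

# Stand-in for the generation loop: just collect the batches each process sees
outputs = [batch for batch in loader]

# gather_for_metrics concatenates results from all processes and drops the samples
# that were duplicated to pad the last batch, so the gathered count matches the
# original number of samples exactly
gathered = accelerator.gather_for_metrics(torch.vstack(outputs))
assert gathered.shape[0] == len(data)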

Maybe we should bump the accelerate version in our dependency list to accelerate>=0.16.0 for this fix?

trlx/setup.cfg, line 14 (at 070c58f):

accelerate>=0.12.0

(This doesn't affect training so maybe not worth the constraint now?)

@maxreciprocate (Collaborator, Author) replied:

This constraint would be benign anyway, since nothing else depends on accelerate. However, we should check that nothing is broken in the new release before deliberately pinning the version here, so let's add it later

@maxreciprocate (Collaborator, Author) commented:

Just to clarify: accelerate.prepare is not a no-op when all elements are passed in conjunction, only when it is called on a sole dataloader after the model & optimizer have already been "prepared". The only function rollout_loader served was to propagate batch_size into deepspeed's config, which could be done manually, but I would leave it as is so as not to break something else accidentally
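
A rough sketch of the two prepare() call patterns being contrasted (the dummy model, optimizer, and loaders are placeholders; whether the second call is a no-op depends on the deepspeed backend, as described above):

import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=1)
rollout_loader = DataLoader(torch.randn(32, 4), batch_size=8, shuffle=True)

# 1) Everything passed in conjunction: the dataloader is actually prepared, and
#    under deepspeed its batch_size gets propagated into deepspeed's config
model, opt, scheduler, rollout_loader = accelerator.prepare(
    model, opt, scheduler, rollout_loader
)

# 2) A sole dataloader prepared after the model & optimizer: effectively a no-op
#    under deepspeed, which is why eval_pipeline is sharded manually in this PR
eval_loader = accelerator.prepare(DataLoader(torch.randn(16, 4), batch_size=8))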

@jon-tow (Collaborator) commented on Feb 4, 2023

Ah gotcha - I misunderstood this discussion. Thanks for clarifying. Merging!

@jon-tow merged commit de8df0f into main on Feb 4, 2023
@jon-tow deleted the fix-ppo-dataloaders branch on February 4, 2023 at 21:22