
Minibatch impl #364

Merged
merged 10 commits into from Apr 6, 2023

Conversation

@Dahoas (Collaborator) commented Mar 13, 2023

Implements minibatching for PPO.

PPO sentiments (batch size 32, minibatch size 16): https://wandb.ai/dahoas/trlx/runs/oo6t8rla/overview?workspace=user-dahoas

PPO HH on GPT-NeoX (batch size 4, minibatch size 1): https://wandb.ai/dahoas/trlx/runs/9rnbmtu6?workspace=user-dahoas

@@ -196,6 +196,9 @@ class TrainConfig:

:param seed: Random seed
:type seed: int

:param minibatch_size: Size of model input during one forward pass. Must divide batch size
Collaborator

(very pedantic) nit, feel free to ignore: usually I've heard this called micro batch, with minibatch referring to what we usually call "a batch" (to distinguish from a single batch of the whole dataset)
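
Whatever the field ends up being called, a rough sketch of how it relates to batch_size, assuming the usual TRLConfig loading used in the examples (path and values illustrative, not from this PR):

from trlx.data.configs import TRLConfig

# Illustrative only: load a default PPO config and enable minibatching.
config = TRLConfig.load_yaml("configs/ppo_config.yml")
config.train.batch_size = 32       # samples consumed per optimizer step
config.train.minibatch_size = 16   # samples per forward/backward pass; must divide batch_size
num_mb = config.train.batch_size // config.train.minibatch_size  # 2 accumulation steps per batch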

stats_accum = []
for mbi in range(self.num_mb):
forward_time -= time()
loss, stats = self.loss(batch)
Collaborator


Shouldn't this get loss using a minibatch sliced from batch?

Collaborator (Author)


Ah yeah good catch (I copied this over from a different branch and forgot to change this)
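
For the record, the intended shape of the loop is roughly the following (a sketch only, mirroring the hunks quoted in this thread; mbs is assumed to hold the self.num_mb minibatches sliced from batch):

stats_accum = []
for mbi in range(self.num_mb):
    mb = mbs[mbi]                                   # minibatch sliced from `batch`
    forward_time -= time()
    loss, stats = self.loss(mb)                     # loss on the minibatch, not the full batch
    forward_time += time()
    backward_time -= time()
    self.accelerator.backward(loss / self.num_mb)   # average gradients over the minibatches
    backward_time += time()
    stats_accum.append(stats)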

forward_time = 0
backward_time = 0
stats_accum = []
for mb in mbs:
Contributor


To avoid unnecessary gradient synchronization when using gradient accumulation, you can simply add self.accelerator.accumulate(self.model) here.

Collaborator


Does that require setting gradient_accumulation_steps for the accelerator? cc @Dahoas

Collaborator


fwiw i tried using that atop this PR and got weird results: https://wandb.ai/uwu1/trlx/reports/Untitled-Report--VmlldzozOTAyNDg4

Contributor


Yes, but we're already specifying that in the config files like zero2-bf16.yaml.

Did you add it right below this line, or literally at the top of the script? It should go right below this line, since entering the context is how accumulate tracks the number of steps being executed so that it knows when to sync (it keeps an internal step counter).

Also, the division of the loss by self.num_mb should go away, since that is handled by the accelerator.

Here's an example: https://github.com/muellerzr/timing_experiments/blob/main/good.py#L152
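
Roughly, that pattern adapted to this loop would look like the following (a sketch only; self.opt and self.scheduler are assumed to be the trainer's optimizer and scheduler, with gradient_accumulation_steps set via the accelerate config):

for mb in mbs:
    with self.accelerator.accumulate(self.model):
        loss, stats = self.loss(mb)
        # accumulate() scales the loss and defers gradient sync (and the optimizer
        # step) until its internal counter reaches gradient_accumulation_steps,
        # so there is no manual `loss / self.num_mb` here.
        self.accelerator.backward(loss)
        self.opt.step()
        self.scheduler.step()
        self.opt.zero_grad()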

@cat-state (Collaborator) Mar 27, 2023

@eluzhnica want to make a PR adding that atop this PR? Then we can merge it in after this one. It would be good to still be able to specify it using the TRLConfig vs. having to use a separate one.

Contributor


@cat-state Okay just made a PR. I'll try to run the same tests that @Dahoas ran to confirm it works as intended.

As for the TRLConfig, happy to do so, but it seems to me that the accelerate configs are set up separately in the repo (configs/accelerate/....), and that is the general pattern I've seen elsewhere too. Accelerate feeds those params behind the scenes for us automatically, so if we were to also specify them in the TRLConfig it would be a bit redundant (and the two values could conflict). Let me know what you think.
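
One way to keep that redundancy harmless, purely as an illustration (hypothetical, not part of this PR), is a guard that checks the two sources agree before training:

# Hypothetical consistency check between the TRLConfig and the accelerate config.
num_mb = config.train.batch_size // config.train.minibatch_size
accel_steps = self.accelerator.gradient_accumulation_steps
if accel_steps not in (1, num_mb):
    raise ValueError(
        f"accelerate gradient_accumulation_steps={accel_steps} conflicts with "
        f"batch_size / minibatch_size = {num_mb} from the TRLConfig"
    )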

@cat-state (Collaborator) left a comment


LGTM!

@@ -468,18 +475,40 @@ def learn(self): # noqa: C901
for _ in range(self.config.train.epochs):
# For each batch
for batch in self.train_dataloader:
mbs = [
PPORLBatch(
Contributor


@Dahoas Just to confirm, this is still a draft, correct? Since (although I haven't run it) I think this would break for ILQLBatch and other batch datatypes.

Collaborator


Ah you're right, I just ran this with the benchmarking scripts and it crashed here with this error

Contributor


Fixed this here: #403
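
Without going into #403's exact implementation, a type-agnostic version of the split could look roughly like this (a sketch; it assumes the batch is a dataclass of equal-length tensors, and the helper name is hypothetical):

from dataclasses import fields

def slice_minibatches(batch, num_mb):
    # Works for PPORLBatch, ILQLBatch, or any dataclass whose fields share a batch dimension.
    cls = type(batch)
    names = [f.name for f in fields(batch)]
    mb_size = len(getattr(batch, names[0])) // num_mb
    return [
        cls(**{n: getattr(batch, n)[i * mb_size:(i + 1) * mb_size] for n in names})
        for i in range(num_mb)
    ]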

@cat-state (Collaborator) left a comment


Need to make it compatible with ILQL.

@Dahoas requested a review from cat-state April 3, 2023 18:43
@Dahoas (Collaborator Author) commented Apr 3, 2023

Merged #403 which makes mini-batching compatible with all trainers. Thank you @eluzhnica!

@Dahoas (Collaborator Author) commented Apr 3, 2023

Let's merge this into main once #396 gets merged.

* Avoid gradient synchronization when accumulating

* Fix accumulation to account for dataloader

* Add some tests
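
(The tests themselves are not shown here. Purely as an illustration of the invariant minibatching relies on, a toy check that accumulating minibatch losses scaled by 1/num_mb reproduces the full-batch gradient; hypothetical, not taken from the PR:)

import torch

def test_minibatch_accumulation_matches_full_batch():
    torch.manual_seed(0)
    x, y = torch.randn(4, 3), torch.randn(4, 1)

    def grad(num_mb):
        model = torch.nn.Linear(3, 1)
        with torch.no_grad():
            model.weight.fill_(1.0)
            model.bias.zero_()
        mb = x.shape[0] // num_mb
        for i in range(num_mb):
            # mean-reduced loss on each minibatch, scaled by 1/num_mb
            loss = torch.nn.functional.mse_loss(model(x[i * mb:(i + 1) * mb]),
                                                y[i * mb:(i + 1) * mb]) / num_mb
            loss.backward()
        return model.weight.grad.clone()

    # accumulating over 2 minibatches should match the single full-batch backward
    assert torch.allclose(grad(1), grad(2), atol=1e-6)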
@Dahoas (Collaborator Author) commented Apr 5, 2023

@cat-state Can you take another look and if it looks good we can merge?

@cat-state (Collaborator)

Thanks @eluzhnica and @Dahoas! Just tried with ILQL and it seems to work: https://wandb.ai/uwu1/trlx/runs/m2e4rwga?workspace=user-uwu1
Thanks for working together and adding this!

@cat-state merged commit 565c316 into main Apr 6, 2023