Minibatch impl #364
Conversation
@@ -196,6 +196,9 @@ class TrainConfig:

    :param seed: Random seed
    :type seed: int

    :param minibatch_size: Size of model input during one forward pass. Must divide batch size
(very pedantic) nit, feel free to ignore: usually I've heard this called micro batch, with minibatch referring to what we usually call "a batch" (as opposed to a single full batch over the whole dataset)
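Whatever the name, the constraint in the new docstring is easy to pin down: the batch is split into `batch_size // minibatch_size` slices, each run through one forward pass. A minimal sketch of how the field and its divisibility check could look (the defaults and the `__post_init__` check are illustrative assumptions, not the PR's exact code):

```python
from dataclasses import dataclass

@dataclass
class TrainConfig:
    batch_size: int = 32
    seed: int = 1000
    # Size of model input during one forward pass; must divide batch_size.
    minibatch_size: int = 16

    def __post_init__(self):
        assert self.batch_size % self.minibatch_size == 0, \
            "minibatch_size must divide batch_size"
        # Number of forward/backward passes accumulated per optimizer step.
        self.num_mb = self.batch_size // self.minibatch_size
```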
stats_accum = []
for mbi in range(self.num_mb):
    forward_time -= time()
    loss, stats = self.loss(batch)
Shouldn't this get the loss using a minibatch sliced from `batch`?
Ah yeah good catch (I copied this over from a different branch and forgot to change this)
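For concreteness, the acknowledged fix could look like this sketch, where each pass computes the loss on a slice of `batch` rather than on the whole batch (`self.mb_size` and the field-by-field slicing are assumptions about the trainer's internals, and `time()` comes from the `time` module as in the surrounding code):

```python
stats_accum = []
for mbi in range(self.num_mb):
    # Slice every field of the batch for this minibatch
    # (assumes a dataclass-style batch such as PPORLBatch).
    sl = slice(mbi * self.mb_size, (mbi + 1) * self.mb_size)
    mb = type(batch)(
        **{name: getattr(batch, name)[sl] for name in batch.__dataclass_fields__}
    )
    forward_time -= time()
    loss, stats = self.loss(mb)
    forward_time += time()
    stats_accum.append(stats)
```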
forward_time = 0
backward_time = 0
stats_accum = []
for mb in mbs:
To avoid unnecessary gradient synchronization when using gradient accumulation, you can simply add `self.accelerator.accumulate(self.model)` here.
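For reference, the pattern from the accelerate docs applied to this loop would look roughly like the following sketch (`self.opt` is an assumption about the trainer's optimizer attribute):

```python
for mb in mbs:
    # Inside accumulate(), accelerate suppresses the gradient all-reduce
    # (and skips optimizer.step()/zero_grad()) until the final
    # accumulation step, as determined by gradient_accumulation_steps.
    with self.accelerator.accumulate(self.model):
        loss, stats = self.loss(mb)
        self.accelerator.backward(loss)
        self.opt.step()
        self.opt.zero_grad()
```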
Does that require setting `gradient_accumulation_steps` for the accelerator? cc @Dahoas
fwiw i tried using that atop this PR and got weird results: https://wandb.ai/uwu1/trlx/reports/Untitled-Report--VmlldzozOTAyNDg4
yes, but we're specifying that already in the config files like `zero2-bf16.yaml`

did you add it right below this line or literally at the top of the script? it should be added right below this line, since entering the context is how `accumulate` tracks the number of steps being executed and knows when to sync (it keeps an internal step counter)

Also, the division of the loss by `self.num_mb` should go away, since that would be handled by the accelerator.

Here's an example: https://github.com/muellerzr/timing_experiments/blob/main/good.py#L152
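To illustrate the loss-scaling point: with gradient accumulation enabled, `accelerator.backward` divides the loss by `gradient_accumulation_steps` internally, so keeping the manual division would scale the loss twice. Roughly (a sketch, not the PR's diff):

```python
# Before: manual accumulation, so the loss is scaled by hand.
loss = loss / self.num_mb
self.accelerator.backward(loss)

# After: inside accelerator.accumulate(...), pass the unscaled loss;
# accelerate applies the 1/gradient_accumulation_steps factor itself.
self.accelerator.backward(loss)
```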
@eluzhnica want to make a PR adding that atop this PR? then we can merge it in after this one. It would be good to still be able to specify it using the TRLConfig vs having to use a separate one
@cat-state Okay just made a PR. I'll try to run the same tests that @Dahoas ran to confirm it works as intended.
As for the TRLConfig, happy to do so, but it seems to me that the configs for accelerate are set up separately in the repo (configs/accelerate/...) and that is the general pattern I've seen elsewhere too. Accelerate feeds those params behind the scenes for us automatically, so if we were to also specify them in the TRLConfig it would be a bit redundant (and could lead to conflicting values). Let me know what you think.
LGTM!
@@ -468,18 +475,40 @@ def learn(self):  # noqa: C901
        for _ in range(self.config.train.epochs):
            # For each batch
            for batch in self.train_dataloader:
                mbs = [
                    PPORLBatch(
@Dahoas Just to confirm, this is still a draft, correct? Since (although I haven't run it) I think this would break for `ILQLBatch` and other datatypes
Ah you're right, I just ran this with the benchmarking scripts and it crashed here with this error
Fixed this here: #403
Need to make it compatible with ILQL
* Add minibatch iterator
* Add tests
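The merged iterator itself isn't shown here, but a generic version in the same spirit can slice any dataclass-style batch field by field, which is what makes it trainer-agnostic. An illustrative sketch, not the code from #403:

```python
from dataclasses import fields, replace

def minibatches(batch, mb_size: int):
    """Yield copies of a dataclass-style batch (PPORLBatch, ILQLBatch, ...)
    with every field sliced to mb_size rows."""
    total = len(getattr(batch, fields(batch)[0].name))
    for start in range(0, total, mb_size):
        yield replace(
            batch,
            **{f.name: getattr(batch, f.name)[start : start + mb_size]
               for f in fields(batch)},
        )
```

Usage would then be `for mb in minibatches(batch, config.train.minibatch_size): ...` inside the training loop.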
Merged #403 which makes mini-batching compatible with all trainers. Thank you @eluzhnica!
Let's merge this into main once #396 gets merged.
* Avoid gradient synchronization when accumulating
* Fix accumulation to account for dataloader
* Add some tests
@cat-state Can you take another look and if it looks good we can merge?
thanks @eluzhnica and @Dahoas! just tried with ILQL and it seems to work: https://wandb.ai/uwu1/trlx/runs/m2e4rwga?workspace=user-uwu1
Implements minibatching for PPO.

PPO sentiments (bs: 32, mbs: 16): https://wandb.ai/dahoas/trlx/runs/oo6t8rla/overview?workspace=user-dahoas

PPO HH on GPT-NeoX (bs: 4, mbs: 1): https://wandb.ai/dahoas/trlx/runs/9rnbmtu6?workspace=user-dahoas