Minibatch impl #364
Changes from 7 commits:
1a243ae, 800c433, 17b543a, a196418, 17b12be, 8994187, f389941, 8485e78, d636448, 833d049
```diff
@@ -16,6 +16,7 @@
 import trlx.utils.logging as logging
 from trlx.data.configs import TRLConfig
+from trlx.data.ppo_types import PPORLBatch
 from trlx.trainer import BaseRLTrainer, register_trainer
 from trlx.utils import (
     filter_non_scalars,
```
```diff
@@ -45,6 +46,12 @@ class AccelerateRLTrainer(BaseRLTrainer):
     def __init__(self, config, **kwargs):  # noqa: C901
         super().__init__(config, **kwargs)
         self.max_length = config.train.seq_length
+        if config.train.minibatch_size:
+            assert config.train.batch_size % config.train.minibatch_size == 0, "Minibatch size must divide batch size"
+            self.mb_size = config.train.minibatch_size
+        else:
+            self.mb_size = config.train.batch_size
+        self.num_mb = config.train.batch_size // self.mb_size
         self.accelerator = Accelerator(log_with=config.train.tracker, logging_dir=config.train.logging_dir)

         if self.accelerator.state.deepspeed_plugin is not None:
```
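As a worked example of the bookkeeping above (hypothetical values, not taken from the PR): with `batch_size = 128` and `minibatch_size = 32`, every optimizer step accumulates gradients over four minibatches.

```python
# Hypothetical config values, mirroring the __init__ logic in the diff above.
batch_size = 128       # config.train.batch_size
minibatch_size = 32    # config.train.minibatch_size (optional)

if minibatch_size:
    assert batch_size % minibatch_size == 0, "Minibatch size must divide batch size"
    mb_size = minibatch_size
else:
    mb_size = batch_size

num_mb = batch_size // mb_size
print(mb_size, num_mb)  # 32 4
```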
```diff
@@ -468,18 +475,40 @@ def learn(self):  # noqa: C901
         for _ in range(self.config.train.epochs):
             # For each batch
             for batch in self.train_dataloader:
+                mbs = [
+                    PPORLBatch(
+                        query_tensors=batch.query_tensors[mbi * self.mb_size : (mbi + 1) * self.mb_size],
+                        response_tensors=batch.response_tensors[mbi * self.mb_size : (mbi + 1) * self.mb_size],
+                        logprobs=batch.logprobs[mbi * self.mb_size : (mbi + 1) * self.mb_size],
+                        values=batch.values[mbi * self.mb_size : (mbi + 1) * self.mb_size],
+                        rewards=batch.rewards[mbi * self.mb_size : (mbi + 1) * self.mb_size],
+                    )
+                    for mbi in range(self.num_mb)
+                ]
```

Review thread (anchored on the `PPORLBatch(` line):

> @Dahoas Just to confirm, this is still a draft, correct? Since (although I haven't run it) I think this would break for `ILQLBatch` and other datatypes.

> Ah, you're right. I just ran this with the benchmarking scripts and it crashed here with this error.

> Fixed this here: #403
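Since the thread above points out that hard-coding the `PPORLBatch` fields breaks for `ILQLBatch` and other batch types, one generic alternative is to slice whatever fields the batch dataclass declares. This is only a sketch (it assumes the batch types are dataclasses whose fields are equally sized tensors), not the fix that was actually merged in #403:

```python
from dataclasses import fields, replace

def slice_batch(batch, mbi: int, mb_size: int):
    """Return a minibatch of the same dataclass type as `batch`,
    slicing every field along the batch dimension."""
    sl = slice(mbi * mb_size, (mbi + 1) * mb_size)
    return replace(batch, **{f.name: getattr(batch, f.name)[sl] for f in fields(batch)})

# Usage, replacing the hard-coded PPORLBatch construction:
# mbs = [slice_batch(batch, mbi, self.mb_size) for mbi in range(self.num_mb)]
```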
```diff
                 # For each update per batch
                 for _ in range(self.n_updates_per_batch):
                     # Note that whereas standard policy gradient methods perform one
                     # gradient update per batch, PPO for example commonly performs
                     # multiple gradient updates on the same batch of data.
                     # https://arxiv.org/pdf/1707.06347.pdf
-                    forward_time = time()
-                    loss, stats = self.loss(batch)
-                    forward_time = time() - forward_time
-                    backward_time = time()
-                    self.accelerator.backward(loss)
-                    backward_time = time() - backward_time
+                    forward_time = 0
+                    backward_time = 0
+                    stats_accum = []
+                    for mb in mbs:
+                        forward_time -= time()
+                        loss, stats = self.loss(mb)
+                        forward_time += time()
+                        loss /= self.num_mb
+                        backward_time -= time()
+                        self.accelerator.backward(loss)
+                        backward_time += time()
+                        stats_accum.append(stats)
```

Review thread (anchored on the `for mb in mbs:` line):

> To avoid unnecessary gradient synchronization when using gradient accumulation, you can simply add:

> Does that require setting

> fwiw I tried using that atop this PR and got weird results: https://wandb.ai/uwu1/trlx/reports/Untitled-Report--VmlldzozOTAyNDg4

> Yes, but we're specifying that already in the config files, like zero2-bf16.yaml. Did you add it right below this line, or literally at the top of the script? It should be added below this line, since the entering of the context is how
> Also, the division of the loss by `self.num_mb` should go away, since that would be handled by accelerator. Here's an example: https://github.com/muellerzr/timing_experiments/blob/main/good.py#L152

> @eluzhnica Want to make a PR adding that atop this PR? Then we can merge it in after this one. It would be good to still be able to specify it using the TRLConfig, vs. having to use a separate one.

> @cat-state Okay, just made a PR. I'll try to run the same tests that @Dahoas ran to confirm it works as intended. As for the TRLConfig, happy to do so, but it seems to me that the configs for accelerate are set up separately in the repo (configs/accelerate/...) and that is the general pattern I've seen before too. And accelerate feeds those params behind the scenes for us automatically, so if we were to also specify them in the TRLConfig it would be a bit redundant (with potentially conflicting values). Let me know what you think.
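The suggestion in this thread maps onto Accelerate's gradient-accumulation context manager. A minimal sketch of how the minibatch loop could look under that proposal, assuming `gradient_accumulation_steps` is configured on the `Accelerator` (e.g. via the accelerate config) and that `self.model` is the prepared model; this is the reviewer's suggestion, not what this PR merges:

```python
# Sketch only: assumes Accelerator(gradient_accumulation_steps=self.num_mb)
# or the equivalent accelerate config file (e.g. configs/accelerate/zero2-bf16.yaml).
for mb in mbs:
    # accumulate() skips gradient synchronization on all but the final
    # minibatch of each accumulation window, and accelerator.backward()
    # scales the loss internally, so the manual `loss /= self.num_mb` is dropped.
    with self.accelerator.accumulate(self.model):
        loss, stats = self.loss(mb)
        self.accelerator.backward(loss)
        stats_accum.append(stats)
```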
```diff
+
+                    forward_time /= self.num_mb
+                    backward_time /= self.num_mb
+                    # TODO(Dahoas): Best way to combine stats between mbs?
+                    # How does accelerate do it?
+                    stats = {key: sum([stats[key] for stats in stats_accum]) / self.num_mb for key in stats_accum[0]}

                     self.opt.step()
                     self.opt.zero_grad()
```
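One sanity check worth keeping in mind for the `loss /= self.num_mb` scaling: with a mean-reduced loss and equally sized minibatches, accumulating gradients over `num_mb` scaled minibatch losses reproduces the full-batch gradient exactly. A self-contained toy verification (hypothetical tensors, not trlx code):

```python
import torch

# Full-batch gradient of a mean-reduced loss.
torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
x = torch.randn(8, 4)

full = (x @ w).pow(2).mean()
full.backward()
g_full = w.grad.clone()

# Accumulate over minibatches, scaling each loss by 1 / num_mb.
w.grad = None
num_mb = 2
for mb in x.chunk(num_mb):
    loss = (mb @ w).pow(2).mean() / num_mb
    loss.backward()

print(torch.allclose(g_full, w.grad))  # True
```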
Review comment:

> (Very pedantic) nit, feel free to ignore: usually I've heard this called "micro batch", with "minibatch" referring to what we usually call "a batch" (as distinguished from a single batch of the whole dataset).