Minibatch impl #364

Merged · 10 commits · Apr 6, 2023
Changes from 7 commits
5 changes: 5 additions & 0 deletions trlx/data/configs.py
@@ -196,6 +196,9 @@ class TrainConfig:

:param seed: Random seed
:type seed: int

+ :param minibatch_size: Size of model input during one forward pass. Must divide batch size
Collaborator:

(very pedantic) nit, feel free to ignore: usually I've heard this called micro batch, with minibatch referring to what we usually call "a batch" (to distinguish from a single batch of the whole dataset)

+ :type minibatch_size: int
"""

total_steps: int
@@ -223,6 +226,8 @@ class TrainConfig:

seed: int = 1000

+ minibatch_size: Optional[int] = None

@classmethod
def from_dict(cls, config: Dict[str, Any]):
return cls(**config)
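
For reference, a minimal sketch (not part of the diff, with hypothetical values) of the arithmetic the new field implies: minibatch_size must evenly divide batch_size, and their ratio is the number of forward/backward passes accumulated per optimizer step.

    # Hypothetical numbers; only the keys relevant to minibatching are shown
    train = {
        "batch_size": 32,      # samples consumed per optimizer step
        "minibatch_size": 8,   # samples per forward/backward pass
    }
    assert train["batch_size"] % train["minibatch_size"] == 0, "minibatch_size must divide batch_size"
    num_mb = train["batch_size"] // train["minibatch_size"]  # 4 accumulation steps per optimizer step
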
41 changes: 35 additions & 6 deletions trlx/trainer/accelerate_base_trainer.py
@@ -16,6 +16,7 @@

import trlx.utils.logging as logging
from trlx.data.configs import TRLConfig
+ from trlx.data.ppo_types import PPORLBatch
from trlx.trainer import BaseRLTrainer, register_trainer
from trlx.utils import (
filter_non_scalars,
@@ -45,6 +46,12 @@ class AccelerateRLTrainer(BaseRLTrainer):
def __init__(self, config, **kwargs): # noqa: C901
super().__init__(config, **kwargs)
self.max_length = config.train.seq_length
+ if config.train.minibatch_size:
+     assert config.train.batch_size % config.train.minibatch_size == 0, "Minibatch size must divide batch size"
+     self.mb_size = config.train.minibatch_size
+ else:
+     self.mb_size = config.train.batch_size
+ self.num_mb = config.train.batch_size // self.mb_size
self.accelerator = Accelerator(log_with=config.train.tracker, logging_dir=config.train.logging_dir)

if self.accelerator.state.deepspeed_plugin is not None:
@@ -468,18 +475,40 @@ def learn(self): # noqa: C901
for _ in range(self.config.train.epochs):
# For each batch
for batch in self.train_dataloader:
+ mbs = [
+     PPORLBatch(
Contributor:

@Dahoas Just to confirm, this is still a draft, correct? Since (although I haven't run it) I think this would break for ILQLBatch and other datatypes.

Collaborator:

Ah you're right, I just ran this with the benchmarking scripts and it crashed here with this error

Contributor:

Fixed this here: #403

+         query_tensors=batch.query_tensors[mbi * self.mb_size : (mbi + 1) * self.mb_size],
+         response_tensors=batch.response_tensors[mbi * self.mb_size : (mbi + 1) * self.mb_size],
+         logprobs=batch.logprobs[mbi * self.mb_size : (mbi + 1) * self.mb_size],
+         values=batch.values[mbi * self.mb_size : (mbi + 1) * self.mb_size],
+         rewards=batch.rewards[mbi * self.mb_size : (mbi + 1) * self.mb_size],
+     )
+     for mbi in range(self.num_mb)
+ ]
# For each update per batch
for _ in range(self.n_updates_per_batch):
# Note that whereas standard policy gradient methods perform one
# gradient update per batch, PPO for example commonly performs
# multiple gradient updates on the same batch of data.
# https://arxiv.org/pdf/1707.06347.pdf
- forward_time = time()
- loss, stats = self.loss(batch)
- forward_time = time() - forward_time
- backward_time = time()
- self.accelerator.backward(loss)
- backward_time = time() - backward_time
+ forward_time = 0
+ backward_time = 0
+ stats_accum = []
+ for mb in mbs:
Contributor:

To avoid unnecessary gradient synchronization when using gradient accumulation, you can simply add self.accelerator.accumulate(self.model) here.

Collaborator:

Does that require setting gradient_accumulation_steps for the accelerator? cc @Dahoas

Collaborator:

FWIW, I tried using that atop this PR and got weird results: https://wandb.ai/uwu1/trlx/reports/Untitled-Report--VmlldzozOTAyNDg4

Contributor:

Yes, but we're already specifying that in the config files like zero2-bf16.yaml.

Did you add it right below this line or literally at the top of the script? It should be added below this line, since entering the context is how accumulate tracks the number of steps being executed, so it knows when to sync (it has an internal step counter).

Also, the division of the loss by self.num_mb should go away, since that would be handled by the accelerator.

Here's an example: https://github.com/muellerzr/timing_experiments/blob/main/good.py#L152

Collaborator (@cat-state, Mar 27, 2023):

@eluzhnica Want to make a PR adding that atop this PR? Then we can merge it in after this one. It would be good to still be able to specify it using the TRLConfig vs. having to use a separate one.

Contributor:

@cat-state Okay just made a PR. I'll try to run the same tests that @Dahoas ran to confirm it works as intended.

As for the TRLConfig, happy to do so, but it seems to me that the accelerate configs are set up separately in the repo (configs/accelerate/...), and that's the general pattern I've seen before too. Accelerate feeds those params behind the scenes for us automatically, so if we were to also specify them in the TRLConfig it would be a bit redundant (and could lead to conflicting values). Let me know what you think.

+     forward_time -= time()
+     loss, stats = self.loss(mb)
+     forward_time += time()
+     loss /= self.num_mb
+     backward_time -= time()
+     self.accelerator.backward(loss)
+     backward_time += time()
+     stats_accum.append(stats)

+ forward_time /= self.num_mb
+ backward_time /= self.num_mb
+ # TODO(Dahoas): Best way to combine stats between mbs?
+ # How does accelerate do it?
+ stats = {key: sum([stats[key] for stats in stats_accum]) / self.num_mb for key in stats_accum[0]}

self.opt.step()
self.opt.zero_grad()
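
To make the accumulate suggestion from the review thread concrete, here is a rough sketch (an illustration of the reviewers' proposal, not the code merged in this PR) of the minibatch loop following the gradient-accumulation pattern from the accelerate documentation. Names such as self.model, self.opt, self.loss, mbs and stats_accum are taken from the diff above, and it assumes gradient_accumulation_steps is set to self.num_mb either on the Accelerator or in the accelerate config:

    stats_accum = []
    for mb in mbs:
        with self.accelerator.accumulate(self.model):
            loss, stats = self.loss(mb)
            # No manual `loss /= self.num_mb` here: accelerator.backward scales the loss,
            # and the prepared optimizer only steps once gradients are synced on the
            # final accumulation step.
            self.accelerator.backward(loss)
            self.opt.step()
            self.opt.zero_grad()
            stats_accum.append(stats)
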
18 changes: 10 additions & 8 deletions trlx/trainer/accelerate_ppo_trainer.py
@@ -434,11 +434,6 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:])

n_samples: int = samples.shape[0]
- logprobs = logprobs.cpu()
- ref_logprobs = ref_logprobs.cpu()
- prompt_tensors = prompt_tensors.cpu()
- sample_outputs = sample_outputs.cpu()
- values = values.cpu()[:, :-1]

# Estimate the KL divergence between the model and reference model
if self.config.model.model_arch_type == "seq2seq":
@@ -447,16 +442,23 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
else:
start = prompt_tensors.shape[1] - 1

+ log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1]
+ self.mean_kl = (log_ratio.exp() - 1 - log_ratio).mean().to(device)

+ logprobs = logprobs.cpu()
+ ref_logprobs = ref_logprobs.cpu()
+ prompt_tensors = prompt_tensors.cpu()
+ sample_outputs = sample_outputs.cpu()
+ values = values.cpu()[:, :-1]

ends = start + attention_mask[:, start:].sum(1)

# Get the logprobs and values, for tokens that are not padding
# or beginning of sequences tokens. These are from the model (not the reference model)
all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)]

- log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1].cpu()
- self.mean_kl = (log_ratio.exp() - 1 - log_ratio).mean().to(device)
- kl_penalty = self.kl_ctl.value * -log_ratio
+ kl_penalty = self.kl_ctl.value * -log_ratio.cpu()
kl_penalty = [xs[start : ends[ix]] for ix, xs in enumerate(kl_penalty)]

rollout_count = 0
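
A closing note on the hunk above: it mainly reorders the computation so that the log-ratio and the always-non-negative estimator exp(r) - 1 - r (often called the k3 KL estimator) are evaluated while the tensors are still on the accelerator device, and the per-token penalty then reuses a .cpu() copy of log_ratio. A standalone sketch with made-up shapes, independent of the trlx internals:

    import torch

    logprobs = torch.randn(4, 15)       # per-token log-probs under the trained model, (batch, seq_len - 1)
    ref_logprobs = torch.randn(4, 15)   # per-token log-probs under the reference model
    attention_mask = torch.ones(4, 16)  # mask over the full token sequence

    log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1]  # zero out padded positions
    mean_kl = (log_ratio.exp() - 1 - log_ratio).mean()              # scalar KL estimate, always >= 0
    kl_penalty = 0.1 * -log_ratio                                   # 0.1 stands in for self.kl_ctl.value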