Skip to content

Commit 18450b0

Browse files
kfertakis, tjruwase, loadams
authored and committed
Enable reward model offloading option (deepspeedai#930)
* enable reward model offloading option * fixed code formatting * more formatting fixes * Pre-commit formatting fix --------- Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Logan Adams <[email protected]> Signed-off-by: zhangsmallshark <[email protected]>
1 parent cab3361 commit 18450b0

File tree

2 files changed

+7
-13
lines changed

2 files changed

+7
-13
lines changed

applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -268,31 +268,22 @@ def _init_reward(self, critic_model_name_or_path):
268268
# If critic is ZeRO-3 then we use it for everything, otherwise assume we have enough memory
269269
zero_stage = 0
270270

271-
ds_config = get_eval_ds_config(offload=self.args.offload,
271+
ds_config = get_eval_ds_config(offload=self.args.offload_reward_model,
272272
dtype=self.args.dtype,
273273
stage=zero_stage)
274-
ds_config[
275-
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
276-
ds_config[
277-
'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
278-
) * self.args.gradient_accumulation_steps
279-
280-
ds_eval_config = get_eval_ds_config(offload=False,
281-
dtype=self.args.dtype,
282-
stage=zero_stage)
283274

284275
# We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine.
285-
ds_eval_config[
276+
ds_config[
286277
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
287-
ds_eval_config[
278+
ds_config[
288279
'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
289280
) * self.args.gradient_accumulation_steps
290281

291282
# Model
292283
reward_model = create_critic_model(
293284
model_name_or_path=critic_model_name_or_path,
294285
tokenizer=self.tokenizer,
295-
ds_config=ds_eval_config,
286+
ds_config=ds_config,
296287
num_padding_at_beginning=self.args.num_padding_at_beginning,
297288
rlhf_training=True,
298289
dropout=self.args.critic_dropout,

applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,9 @@ def parse_args():
246246
'--offload_reference_model',
247247
action='store_true',
248248
help='Enable ZeRO Offload techniques for reference model')
249+
parser.add_argument('--offload_reward_model',
250+
action='store_true',
251+
help='Enable ZeRO Offload techniques for reward model')
249252
parser.add_argument(
250253
'--actor_zero_stage',
251254
type=int,

0 commit comments

Comments (0)