File tree: 2 files changed, +7 −13 lines
- applications/DeepSpeed-Chat/training/step3_rlhf_finetuning
@@ -268,31 +268,22 @@ def _init_reward(self, critic_model_name_or_path):
         # If critic is ZeRO-3 then we use it for everything, otherwise assume we have enough memory
         zero_stage = 0

-        ds_config = get_eval_ds_config(offload=self.args.offload,
+        ds_config = get_eval_ds_config(offload=self.args.offload_reward_model,
                                        dtype=self.args.dtype,
                                        stage=zero_stage)
-        ds_config[
-            'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
-        ds_config[
-            'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
-            ) * self.args.gradient_accumulation_steps
-
-        ds_eval_config = get_eval_ds_config(offload=False,
-                                            dtype=self.args.dtype,
-                                            stage=zero_stage)

         # We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine.
-        ds_eval_config[
+        ds_config[
             'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
-        ds_eval_config[
+        ds_config[
             'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
             ) * self.args.gradient_accumulation_steps

         # Model
         reward_model = create_critic_model(
             model_name_or_path=critic_model_name_or_path,
             tokenizer=self.tokenizer,
-            ds_config=ds_eval_config,
+            ds_config=ds_config,
             num_padding_at_beginning=self.args.num_padding_at_beginning,
             rlhf_training=True,
             dropout=self.args.critic_dropout,
@@ -246,6 +246,9 @@ def parse_args():
         '--offload_reference_model',
         action='store_true',
         help='Enable ZeRO Offload techniques for reference model')
+    parser.add_argument('--offload_reward_model',
+                        action='store_true',
+                        help='Enable ZeRO Offload techniques for reward model')
     parser.add_argument(
         '--actor_zero_stage',
         type=int,
You can’t perform that action at this time.
0 commit comments