Skip to content

Commit d283ee2

Browse files
committed
feat: support parallel reward function
1 parent e085ba2 commit d283ee2

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

trlx/trainer/accelerate_ppo_trainer.py

+4 -1
Original file line number | Diff line number | Diff line change
@@ -351,7 +351,10 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
351 351
scores = all_scores
352 352
scores_mask = scores != -np.inf
353 353

354-
str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True)
354+
if self.config.train.reward_only_in_main_process:
355+
str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True)
356+
else:
357+
str_samples, str_prompts, str_outputs = all_str_samples, all_str_prompts, all_str_outputs
355 358

356 359
# Pad the sample outputs
357 360
outputs = self.tokenizer(str_outputs).input_ids

0 commit comments

Comments
 (0)