Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tests/experimental/test_grpo_with_replay_buffer_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,9 @@ def test_update_with_inputs_different_seq_len(self):


@pytest.mark.low_priority
@pytest.mark.parametrize("scale_rewards", ["batch", "group"])
class TestGRPOWithReplayBufferTrainer(TrlTestCase):
def test_training_with_replay_buffer(self):
def test_training_with_replay_buffer(self, scale_rewards):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

# Guarantee that some rewards have 0 std
Expand All @@ -269,6 +270,7 @@ def custom_reward_func(completions, **kwargs):
max_completion_length=8, # reduce the completion length to reduce memory usage
replay_buffer_size=8,
report_to="none",
scale_rewards=scale_rewards,
)
trainer = GRPOWithReplayBufferTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,12 @@ def _generate_and_score_completions(
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = rewards - mean_grouped_rewards

grouped_std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
grouped_std_rewards = grouped_std_rewards.repeat_interleave(self.num_generations, dim=0)

if self.scale_rewards in ["group", "none"]:
# If self.scale_rewards = "none", we'll still log group level std
std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0)
std_rewards = grouped_std_rewards.clone()
elif self.scale_rewards == "batch":
# Compute global std
std_rewards = rewards.std().expand_as(rewards)
Expand All @@ -261,7 +263,7 @@ def _generate_and_score_completions(
)
all_process_advantages = advantages.clone() # keep the aggregated advantages for logging
advantages = advantages[process_slice]
std_rewards = std_rewards[process_slice]
grouped_std_rewards = grouped_std_rewards[process_slice]

# Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
for i, reward_func_name in enumerate(self.reward_func_names):
Expand Down Expand Up @@ -316,7 +318,7 @@ def _generate_and_score_completions(
)
outputs_after_sampling_buffer = self.update_with_replay_buffer(
advantages,
std_rewards,
grouped_std_rewards,
prompt_ids,
prompt_mask,
completion_ids,
Expand Down
Loading