diff --git a/tests/experimental/test_grpo_with_replay_buffer_trainer.py b/tests/experimental/test_grpo_with_replay_buffer_trainer.py index 6ab0fdb2887..181f7204793 100644 --- a/tests/experimental/test_grpo_with_replay_buffer_trainer.py +++ b/tests/experimental/test_grpo_with_replay_buffer_trainer.py @@ -250,8 +250,9 @@ def test_update_with_inputs_different_seq_len(self): @pytest.mark.low_priority +@pytest.mark.parametrize("scale_rewards", ["batch", "group"]) class TestGRPOWithReplayBufferTrainer(TrlTestCase): - def test_training_with_replay_buffer(self): + def test_training_with_replay_buffer(self, scale_rewards): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") # Guarantee that some rewards have 0 std @@ -269,6 +270,7 @@ def custom_reward_func(completions, **kwargs): max_completion_length=8, # reduce the completion length to reduce memory usage replay_buffer_size=8, report_to="none", + scale_rewards=scale_rewards, ) trainer = GRPOWithReplayBufferTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", diff --git a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py index 597d6218084..c95dbc18fac 100644 --- a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py +++ b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py @@ -238,10 +238,12 @@ def _generate_and_score_completions( mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) advantages = rewards - mean_grouped_rewards + grouped_std_rewards = rewards.view(-1, self.num_generations).std(dim=1) + grouped_std_rewards = grouped_std_rewards.repeat_interleave(self.num_generations, dim=0) + if self.scale_rewards in ["group", "none"]: # If self.scale_rewards = "none", we'll still log group level std - std_rewards = rewards.view(-1, self.num_generations).std(dim=1) - std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0) + std_rewards = grouped_std_rewards.clone() elif self.scale_rewards == "batch": # Compute global std std_rewards = rewards.std().expand_as(rewards) @@ -261,7 +263,7 @@ def _generate_and_score_completions( ) all_process_advantages = advantages.clone() # keep the aggregated advantages for logging advantages = advantages[process_slice] - std_rewards = std_rewards[process_slice] + grouped_std_rewards = grouped_std_rewards[process_slice] # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) for i, reward_func_name in enumerate(self.reward_func_names): @@ -316,7 +318,7 @@ def _generate_and_score_completions( ) outputs_after_sampling_buffer = self.update_with_replay_buffer( advantages, - std_rewards, + grouped_std_rewards, prompt_ids, prompt_mask, completion_ids,