diff --git a/tests/algorithm/advantage_fn_test.py b/tests/algorithm/advantage_fn_test.py
index 5d03ca55c1..b9f8dab858 100644
--- a/tests/algorithm/advantage_fn_test.py
+++ b/tests/algorithm/advantage_fn_test.py
@@ -326,3 +326,71 @@ def test_batch_level_step_wise_grpo_advantage(self):
         expected_advantages = expected_advantage_value * target_exp.action_mask
         self.assertTrue(torch.allclose(target_exp.advantages, expected_advantages, atol=1e-6))
         self.assertTrue(torch.allclose(target_exp.returns, expected_advantages, atol=1e-6))
+
+    def test_step_wise_grpo_with_std_threshold(self):
+        advantage_fn_cls = ADVANTAGE_FN.get("step_wise_grpo")
+        self.assertIsNotNone(advantage_fn_cls)
+        advantage_fn = advantage_fn_cls(epsilon=1e-6, std_threshold=0.0001)
+        repeat_times = 5
+        step_num = 4
+
+        # Create experiences with mixed reward patterns:
+        # - task 0: all runs have same reward (0.5) -> should be filtered
+        # - task 1: all runs have same reward (1.0) -> should be filtered
+        # - task 2: runs have different rewards (0, 1, 2, 3, 4) -> should NOT be filtered
+        exps = []
+
+        # Task 0: constant reward 0.5
+        for k in range(step_num):
+            for i in range(repeat_times):
+                exps.append(
+                    Experience(
+                        eid=EID(batch=0, task=0, run=i, step=k),
+                        tokens=torch.zeros(5),
+                        prompt_length=2,
+                        reward=0.5,
+                    )
+                )
+
+        # Task 1: constant reward 1.0
+        for k in range(step_num):
+            for i in range(repeat_times):
+                exps.append(
+                    Experience(
+                        eid=EID(batch=0, task=1, run=i, step=k),
+                        tokens=torch.zeros(5),
+                        prompt_length=2,
+                        reward=1.0,
+                    )
+                )
+
+        # Task 2: varying rewards
+        for k in range(step_num):
+            for i in range(repeat_times):
+                exps.append(
+                    Experience(
+                        eid=EID(batch=0, task=2, run=i, step=k),
+                        tokens=torch.zeros(5),
+                        prompt_length=2,
+                        reward=float(i),
+                    )
+                )
+
+        processed_exps, metrics = advantage_fn(exps)
+
+        # Only task 2 should remain (task 0 and task 1 filtered due to zero std)
+        expected_remaining = repeat_times * step_num  # task 2 only
+        expected_filtered = 2 * repeat_times * step_num  # task 0 and task 1
+
+        self.assertEqual(len(processed_exps), expected_remaining)
+        self.assertIn("filtered_count", metrics)
+        self.assertEqual(metrics["filtered_count"], expected_filtered)
+
+        # Verify skipped group ratio: 2 out of 3 tasks were skipped
+        self.assertIn("skipped_group_ratio", metrics)
+        expected_ratio = 2.0 / 3.0  # task 0 and task 1 skipped out of 3 total tasks
+        self.assertAlmostEqual(metrics["skipped_group_ratio"], expected_ratio, places=6)
+
+        # Verify that all remaining experiences are from task 2
+        for exp in processed_exps:
+            self.assertEqual(exp.eid.task, 2)
diff --git a/trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py b/trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py
index a085ee1ae2..3c11daf203 100644
--- a/trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py
+++ b/trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py
@@ -22,6 +22,7 @@ def __init__(
         epsilon: float = 1e-6,
         enable_step_norm: bool = False,
         std_cal_level: str = "group",  # 'group' (task-level) or 'batch'
+        std_threshold: Optional[float] = None,
         **kwargs,
     ) -> None:
         """Initialize the Step-wise GRPO advantage function.
@@ -33,10 +34,13 @@ def __init__(
                 'group' (default): Std is calculated per task group.
                 'batch': Std is calculated across all last-step rewards in the entire batch.
                 The mean is always calculated per task group.
+            std_threshold (Optional[float]): If provided, task groups with a reward standard deviation
+                equal to or below this threshold will be skipped.
""" self.epsilon = epsilon self.enable_step_norm = enable_step_norm self.std_cal_level = std_cal_level + self.std_threshold = std_threshold if self.std_cal_level not in ["group", "batch"]: raise ValueError("std_cal_level must be either 'group' or 'batch'") @@ -44,15 +48,17 @@ def calculate_last_step_advantage( self, exps: Dict[str, Experience], precomputed_std: Optional[torch.Tensor] = None, - ) -> Tuple[Dict[str, float], Dict[str, float]]: + ) -> Tuple[Dict[str, float], Dict[str, float], bool]: """Calculate group advantage for a given group of experiences. Args: exps (Dict[str, Experience]): One experience per run, keyed by run ID. + precomputed_std (Optional[torch.Tensor]): Precomputed standard deviation for batch-level calculation. Returns: - Dict[str, float]: A tuple containing the scores for each run. + Dict[str, float]: Scores for each run. Dict[str, float]: Metrics for logging. + bool: Whether this group should be skipped. """ with torch.no_grad(): if len(exps) == 1: @@ -62,6 +68,13 @@ def calculate_last_step_advantage( rewards = torch.tensor([exp.reward for exp in exps.values()], dtype=torch.float32) group_reward_mean = torch.mean(rewards) group_reward_std = torch.std(rewards) + + # Determine if this group should be skipped based on std_threshold + should_skip = False + if self.std_threshold is not None: + if len(exps) == 1 or group_reward_std <= self.std_threshold: + should_skip = True + scores = {} for rid, exp in exps.items(): if self.std_cal_level == "batch" and precomputed_std is not None: @@ -73,7 +86,7 @@ def calculate_last_step_advantage( "reward_mean": group_reward_mean.item(), "reward_std": group_reward_std.item(), } - return scores, metrics + return scores, metrics, should_skip def broadcast_advantages( self, run_exps: Dict[str, List[Experience]], scores: Dict[str, float] @@ -102,6 +115,7 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]: return [], {} cnt = 0 metric_list = [] + filtered_count = 0 # Step 1: split the experiences into sub-groups by task task_exps = group_by(exps, "task") @@ -126,14 +140,27 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]: # Step 2: further split each task's experiences into sub-groups by run result_exps = [] + total_task_groups = len(task_exps) + skipped_task_groups = 0 + for task_exp in task_exps.values(): run_exps = group_by(task_exp, "run") # Step3: extract the last experience (last step) from each run and calculate scores last_step_exps = {run_id: step_exps[-1] for run_id, step_exps in run_exps.items()} - scores, metrics = self.calculate_last_step_advantage( + scores, metrics, should_skip = self.calculate_last_step_advantage( last_step_exps, precomputed_std=precomputed_std ) + + # Skip this task group if std is below threshold + if should_skip: + # Count all experiences in this task group as filtered + task_exp_count = sum(len(step_exps) for step_exps in run_exps.values()) + filtered_count += task_exp_count + skipped_task_groups += 1 + metric_list.append(metrics) + continue + metric_list.append(metrics) # Step 4: broadcast the advantages to all previous steps @@ -144,6 +171,14 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]: metrics = gather_metrics(metric_list, "group_advantages") metrics["experience_count"] = cnt + metrics["filtered_count"] = filtered_count + + # Calculate the ratio of skipped task groups + if total_task_groups > 0: + metrics["skipped_group_ratio"] = skipped_task_groups / total_task_groups + else: + metrics["skipped_group_ratio"] 
= 0.0 + return result_exps, metrics def __call__(self, exps, **kwargs): @@ -160,4 +195,6 @@ def default_args(cls) -> Dict: return { "epsilon": 1e-6, "enable_step_norm": False, + "std_threshold": None, + "std_cal_level": "group", }