68 changes: 68 additions & 0 deletions tests/algorithm/advantage_fn_test.py
@@ -326,3 +326,71 @@ def test_batch_level_step_wise_grpo_advantage(self):
expected_advantages = expected_advantage_value * target_exp.action_mask
self.assertTrue(torch.allclose(target_exp.advantages, expected_advantages, atol=1e-6))
self.assertTrue(torch.allclose(target_exp.returns, expected_advantages, atol=1e-6))

def test_step_wise_grpo_with_std_threshold(self):
advantage_fn_cls = ADVANTAGE_FN.get("step_wise_grpo")
self.assertIsNotNone(advantage_fn_cls)
advantage_fn = advantage_fn_cls(epsilon=1e-6, std_threshold=0.0001)
repeat_times = 5
step_num = 4

# Create experiences with mixed reward patterns:
# - task 0: all runs have same reward (0.5) -> should be filtered
# - task 1: all runs have same reward (1.0) -> should be filtered
# - task 2: runs have different rewards (0, 1, 2, 3, 4) -> should NOT be filtered
exps = []

# Task 0: constant reward 0.5
for k in range(step_num):
for i in range(repeat_times):
exps.append(
Experience(
eid=EID(batch=0, task=0, run=i, step=k),
tokens=torch.zeros(5),
prompt_length=2,
reward=0.5,
)
)

# Task 1: constant reward 1.0
for k in range(step_num):
for i in range(repeat_times):
exps.append(
Experience(
eid=EID(batch=0, task=1, run=i, step=k),
tokens=torch.zeros(5),
prompt_length=2,
reward=1.0,
)
)

# Task 2: varying rewards
for k in range(step_num):
for i in range(repeat_times):
exps.append(
Experience(
eid=EID(batch=0, task=2, run=i, step=k),
tokens=torch.zeros(5),
prompt_length=2,
reward=float(i),
)
)

processed_exps, metrics = advantage_fn(exps)

# Only task 2 should remain (task 0 and task 1 filtered due to zero std)
expected_remaining = repeat_times * step_num # task 2 only
expected_filtered = 2 * repeat_times * step_num # task 0 and task 1

self.assertEqual(len(processed_exps), expected_remaining)
self.assertIn("filtered_count", metrics)
self.assertEqual(metrics["filtered_count"], expected_filtered)

# Verify skipped group ratio: 2 out of 3 tasks were skipped
self.assertIn("skipped_group_ratio", metrics)
expected_ratio = 2.0 / 3.0 # task 0 and task 1 skipped out of 3 total tasks
self.assertAlmostEqual(metrics["skipped_group_ratio"], expected_ratio, places=6)

# Verify that all remaining experiences are from task 2
for exp in processed_exps:
self.assertEqual(exp.eid.task, 2)
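Why tasks 0 and 1 are filtered while task 2 survives comes down to `torch.std` on the per-task reward groups: constant rewards give exactly zero standard deviation, which is at or below the `std_threshold=0.0001` passed above. A standalone sketch of the check:

```python
import torch

std_threshold = 0.0001  # same threshold the test passes to the advantage function

constant_rewards = torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5])  # tasks 0/1 pattern
varying_rewards = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0])   # task 2 pattern

print(torch.std(constant_rewards) <= std_threshold)  # tensor(True)  -> skipped
print(torch.std(varying_rewards) <= std_threshold)   # tensor(False) -> kept
```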
45 changes: 41 additions & 4 deletions trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py
@@ -22,6 +22,7 @@ def __init__(
epsilon: float = 1e-6,
enable_step_norm: bool = False,
std_cal_level: str = "group", # 'group' (task-level) or 'batch'
std_threshold: Optional[float] = None,
**kwargs,
) -> None:
"""Initialize the Step-wise GRPO advantage function.
@@ -33,26 +34,31 @@ def __init__(
'group' (default): Std is calculated per task group.
'batch': Std is calculated across all last-step rewards in the entire batch.
The mean is always calculated per task group.
std_threshold (Optional[float]): If provided, task groups whose reward standard
deviation is equal to or below this threshold will be skipped.
"""
self.epsilon = epsilon
self.enable_step_norm = enable_step_norm
self.std_cal_level = std_cal_level
self.std_threshold = std_threshold
if self.std_cal_level not in ["group", "batch"]:
raise ValueError("std_cal_level must be either 'group' or 'batch'")

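For readers skimming the diff, the difference between the two `std_cal_level` modes described in the docstring can be condensed into a few lines (a sketch with made-up rewards; `group_rewards` and the task names are illustrative only):

```python
import torch

# Illustrative last-step rewards for two task groups (hypothetical values).
group_rewards = {
    "task_a": torch.tensor([1.0, 2.0, 3.0]),
    "task_b": torch.tensor([0.0, 4.0, 8.0]),
}

# 'group' (default): each task group normalizes by its own std.
group_stds = {task: torch.std(r) for task, r in group_rewards.items()}

# 'batch': a single std across all last-step rewards in the batch.
batch_std = torch.std(torch.cat(list(group_rewards.values())))

# In both modes the mean is always computed per task group.
group_means = {task: torch.mean(r) for task, r in group_rewards.items()}
```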
def calculate_last_step_advantage(
self,
exps: Dict[str, Experience],
precomputed_std: Optional[torch.Tensor] = None,
- ) -> Tuple[Dict[str, float], Dict[str, float]]:
+ ) -> Tuple[Dict[str, float], Dict[str, float], bool]:
"""Calculate group advantage for a given group of experiences.

Args:
exps (Dict[str, Experience]): One experience per run, keyed by run ID.
precomputed_std (Optional[torch.Tensor]): Precomputed standard deviation for batch-level calculation.

Returns:
- Dict[str, float]: A tuple containing the scores for each run.
+ Dict[str, float]: Scores for each run.
+ Dict[str, float]: Metrics for logging.
+ bool: Whether this group should be skipped.
"""
with torch.no_grad():
if len(exps) == 1:
@@ -62,6 +68,13 @@ def calculate_last_step_advantage(
rewards = torch.tensor([exp.reward for exp in exps.values()], dtype=torch.float32)
group_reward_mean = torch.mean(rewards)
group_reward_std = torch.std(rewards)

# Determine if this group should be skipped based on std_threshold
should_skip = False
if self.std_threshold is not None:
if len(exps) == 1 or group_reward_std <= self.std_threshold:
should_skip = True

scores = {}
for rid, exp in exps.items():
if self.std_cal_level == "batch" and precomputed_std is not None:
@@ -73,7 +86,7 @@ def calculate_last_step_advantage(
"reward_mean": group_reward_mean.item(),
"reward_std": group_reward_std.item(),
}
- return scores, metrics
+ return scores, metrics, should_skip

def broadcast_advantages(
self, run_exps: Dict[str, List[Experience]], scores: Dict[str, float]
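The line that turns a reward into a score is folded out of this hunk, but given the `epsilon` argument and the group mean/std computed above, it plausibly follows the standard GRPO normalization `(reward - mean) / (std + epsilon)`. Under that assumption, task 2's rewards from the test would normalize as follows:

```python
import torch

epsilon = 1e-6
rewards = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0])    # task 2's last-step rewards
mean, std = torch.mean(rewards), torch.std(rewards)  # 2.0 and ~1.5811 (unbiased)

# Assumed GRPO-style normalization; the exact expression is hidden by the fold.
scores = (rewards - mean) / (std + epsilon)
print(scores)  # tensor([-1.2649, -0.6325,  0.0000,  0.6325,  1.2649])
```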
@@ -102,6 +115,7 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
return [], {}
cnt = 0
metric_list = []
filtered_count = 0
# Step 1: split the experiences into sub-groups by task
task_exps = group_by(exps, "task")

@@ -126,14 +140,27 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:

# Step 2: further split each task's experiences into sub-groups by run
result_exps = []
total_task_groups = len(task_exps)
skipped_task_groups = 0

for task_exp in task_exps.values():
run_exps = group_by(task_exp, "run")

# Step 3: extract the last experience (last step) from each run and calculate scores
last_step_exps = {run_id: step_exps[-1] for run_id, step_exps in run_exps.items()}
- scores, metrics = self.calculate_last_step_advantage(
+ scores, metrics, should_skip = self.calculate_last_step_advantage(
last_step_exps, precomputed_std=precomputed_std
)

# Skip this task group if std is below threshold
if should_skip:
# Count all experiences in this task group as filtered
task_exp_count = sum(len(step_exps) for step_exps in run_exps.values())
filtered_count += task_exp_count
skipped_task_groups += 1
metric_list.append(metrics)
continue

metric_list.append(metrics)

# Step 4: broadcast the advantages to all previous steps
Expand All @@ -144,6 +171,14 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:

metrics = gather_metrics(metric_list, "group_advantages")
metrics["experience_count"] = cnt
metrics["filtered_count"] = filtered_count

# Calculate the ratio of skipped task groups
if total_task_groups > 0:
metrics["skipped_group_ratio"] = skipped_task_groups / total_task_groups
else:
metrics["skipped_group_ratio"] = 0.0

return result_exps, metrics
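As a cross-check against the assertions in `advantage_fn_test.py`, the counting behind the two new metrics is plain arithmetic:

```python
repeat_times, step_num = 5, 4

total_task_groups = 3    # tasks 0, 1, 2
skipped_task_groups = 2  # tasks 0 and 1: zero reward std <= threshold

filtered_count = skipped_task_groups * repeat_times * step_num  # 2 * 5 * 4 = 40
remaining = (total_task_groups - skipped_task_groups) * repeat_times * step_num  # 20
skipped_group_ratio = skipped_task_groups / total_task_groups   # 2 / 3 ~ 0.6667

assert (filtered_count, remaining) == (40, 20)
```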

def __call__(self, exps, **kwargs):
@@ -160,4 +195,6 @@ def default_args(cls) -> Dict:
return {
"epsilon": 1e-6,
"enable_step_norm": False,
"std_threshold": None,
"std_cal_level": "group",
}
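Putting it together, enabling the filter from user code would look roughly like the test does it. This is a sketch: the import paths for `ADVANTAGE_FN`, `Experience`, and `EID` are not shown in the diff, so they are assumed to match whatever the test file imports.

```python
import torch

# Assumes ADVANTAGE_FN, Experience, and EID are imported as in
# tests/algorithm/advantage_fn_test.py (imports not shown in this diff).
advantage_fn_cls = ADVANTAGE_FN.get("step_wise_grpo")
advantage_fn = advantage_fn_cls(epsilon=1e-6, std_threshold=1e-4)

# One task whose runs earn distinct rewards, so nothing gets filtered.
exps = [
    Experience(
        eid=EID(batch=0, task=0, run=i, step=k),
        tokens=torch.zeros(5),
        prompt_length=2,
        reward=float(i),
    )
    for k in range(4)   # step_num
    for i in range(5)   # repeat_times
]

processed_exps, metrics = advantage_fn(exps)
print(metrics["filtered_count"], metrics["skipped_group_ratio"])  # 0 and 0.0
```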