diff --git a/tests/utils/debug/test_metrics.py b/tests/utils/debug/test_metrics.py index 1b2f7f8faa1..2caf9fea123 100644 --- a/tests/utils/debug/test_metrics.py +++ b/tests/utils/debug/test_metrics.py @@ -16,7 +16,7 @@ import torch from verl.protocol import DataProto -from verl.utils.debug.metrics import calculate_debug_metrics +from verl.utils.debug.metrics import _find_contiguous_segments, calculate_debug_metrics class TestMetrics(unittest.TestCase): @@ -43,6 +43,95 @@ def test_calculate_debug_metrics(self): print(metrics) assert metrics["training/rollout_probs_diff_valid"] == 1 + def test_find_contiguous_segments(self): + # Single segment + mask = torch.tensor([1, 1, 1, 0, 0]) + assert _find_contiguous_segments(mask) == [(0, 3)] + + # Multiple segments (multi-turn) + mask = torch.tensor([1, 1, 0, 0, 1, 1, 1, 0, 1]) + assert _find_contiguous_segments(mask) == [(0, 2), (4, 7), (8, 9)] + + # All zeros + mask = torch.tensor([0, 0, 0]) + assert _find_contiguous_segments(mask) == [] + + # All ones + mask = torch.tensor([1, 1, 1]) + assert _find_contiguous_segments(mask) == [(0, 3)] + + def test_per_round_metrics_single_turn(self): + """Single contiguous response should produce 1 round.""" + data = DataProto.from_dict( + { + "rollout_log_probs": torch.tensor([[-1.0, -2.0, -3.0, -4.0]]), + "old_log_probs": torch.tensor([[-1.1, -2.1, -3.1, -4.1]]), + "response_mask": torch.tensor([[1, 1, 1, 1]]), + "responses": torch.zeros((1, 4)), + } + ) + metrics = calculate_debug_metrics(data) + assert metrics["per_round/total_rounds"] == 1 + assert "per_round/round_0_abs_diff_mean" in metrics + self.assertAlmostEqual(metrics["per_round/round_0_abs_diff_mean"], 0.1, places=5) + + def test_per_round_metrics_multi_turn(self): + """Multi-turn: two rounds separated by env tokens.""" + # Round 0: positions 0-1, identical logprobs -> diff=0 + # Round 1: positions 4-5, different logprobs -> diff=1.0 + data = DataProto.from_dict( + { + "rollout_log_probs": torch.tensor([[-1.0, -2.0, -9.0, 
-9.0, -3.0, -4.0]]), + "old_log_probs": torch.tensor([[-1.0, -2.0, -9.0, -9.0, -4.0, -5.0]]), + "response_mask": torch.tensor([[1, 1, 0, 0, 1, 1]]), + "responses": torch.zeros((1, 6)), + } + ) + metrics = calculate_debug_metrics(data) + assert metrics["per_round/total_rounds"] == 2 + # Round 0: identical logprobs + self.assertAlmostEqual(metrics["per_round/round_0_abs_diff_mean"], 0.0, places=5) + # Round 1: diff of 1.0 each + self.assertAlmostEqual(metrics["per_round/round_1_abs_diff_mean"], 1.0, places=5) + # Max diff should be round 1 + assert metrics["per_round/max_round_diff"] == 1 + self.assertAlmostEqual(metrics["per_round/max_diff_value"], 1.0, places=5) + + def test_per_round_metrics_batch(self): + """Batch with different number of rounds per sample.""" + # Sample 0: 1 round (positions 0-2) + # Sample 1: 2 rounds (positions 0-1, positions 3-4) + data = DataProto.from_dict( + { + "rollout_log_probs": torch.tensor( + [ + [-1.0, -2.0, -3.0, -9.0, -9.0], + [-1.0, -2.0, -9.0, -3.0, -4.0], + ] + ), + "old_log_probs": torch.tensor( + [ + [-1.0, -2.0, -3.0, -9.0, -9.0], + [-1.0, -2.0, -9.0, -3.5, -4.5], + ] + ), + "response_mask": torch.tensor( + [ + [1, 1, 1, 0, 0], + [1, 1, 0, 1, 1], + ] + ), + "responses": torch.zeros((2, 5)), + } + ) + metrics = calculate_debug_metrics(data) + # Max rounds across batch is 2 + assert metrics["per_round/total_rounds"] == 2 + assert "per_round/round_0_abs_diff_mean" in metrics + assert "per_round/round_1_abs_diff_mean" in metrics + assert metrics["per_round/round_0_token_count"] == 5 # 3 from sample 0 + 2 from sample 1 + assert metrics["per_round/round_1_token_count"] == 2 # only from sample 1 + if __name__ == "__main__": unittest.main() diff --git a/verl/utils/debug/metrics.py b/verl/utils/debug/metrics.py index e7d57a2fec3..82ea3ce3404 100644 --- a/verl/utils/debug/metrics.py +++ b/verl/utils/debug/metrics.py @@ -60,6 +60,118 @@ def calculate_log_prob_diff(log_probs1: torch.Tensor, log_probs2: torch.Tensor, return 
def _find_contiguous_segments(mask_1d: torch.Tensor) -> list[tuple[int, int]]:
    """Find contiguous runs of set (non-zero) entries in a 1D mask.

    Each run represents one round of model generation in a multi-turn
    trajectory. Runs are separated by 0s (environment tokens like images,
    or padding).

    Example:
        mask = [1,1,1,1, 0,0,0,0,0, 1,1,1, 0,0,0,0,0, 1,1,1, 0,0,0]
               |--R0--| |--env--|  |--R1-| |--env--|  |--R2-| |pad|
        Returns: [(0, 4), (9, 12), (17, 20)]

    Args:
        mask_1d: 1D tensor (a plain Python sequence is accepted too). Any
            non-zero / truthy entry counts as set, which matches how
            ``calculate_debug_metrics`` interprets the same mask through
            ``response_mask.bool()``.

    Returns:
        List of (start, end) tuples, one per run, in increasing order.
        ``end`` is exclusive (Python slice convention). Empty if no entry
        is set.
    """
    # Convert once instead of indexing the tensor position by position:
    # per-element indexing plus ``.item()`` creates a temporary tensor per token
    # and, for a CUDA mask, needs a device-to-host copy per token.
    values = mask_1d.tolist() if isinstance(mask_1d, torch.Tensor) else list(mask_1d)

    segments: list[tuple[int, int]] = []
    run_start = None  # start index of the run currently open, if any
    for idx, val in enumerate(values):
        if val:
            if run_start is None:
                run_start = idx
        elif run_start is not None:
            segments.append((run_start, idx))
            run_start = None

    # A run that reaches the end of the mask has no terminating zero.
    if run_start is not None:
        segments.append((run_start, len(values)))

    return segments


def _calculate_per_round_metrics(
    train_log_probs: torch.Tensor,
    rollout_log_probs: torch.Tensor,
    response_mask: torch.Tensor,
) -> dict:
    """Calculate per-round logprob mismatch metrics for multi-turn trajectories.

    Rounds are the contiguous runs of set entries in each row of
    ``response_mask``. The i-th run of every sample is pooled into "round i",
    and the mean absolute logprob difference is taken over all pooled tokens
    (so samples with longer rounds weigh more).

    This is useful for multi-turn RL training where different rounds may have
    different attention mask behavior (e.g., image window attention), causing
    mismatch between training and rollout engines to vary across rounds.

    Args:
        train_log_probs: Log probs from training engine (batch_size, seq_len)
        rollout_log_probs: Log probs from rollout engine (batch_size, seq_len)
        response_mask: Mask for valid positions (batch_size, seq_len),
            non-zero=model generated token, 0=environment token or padding

    Returns:
        Dictionary with per-round metrics:
        - per_round/total_rounds: Max number of rounds across batch
        - per_round/round_{i}_abs_diff_mean: Mean |logprob_train - logprob_rollout| for round i
        - per_round/round_{i}_token_count: Number of tokens in round i
        - per_round/max_round_diff: Which round has the largest mean diff
          (-1 if no round produced a mean above the -1.0 sentinel, e.g. all NaN)
        - per_round/max_diff_value: The largest mean diff value (only present
          when max_round_diff >= 0)
        If no sample contains a round, only ``per_round/total_rounds`` (= 0)
        is returned.
    """
    batch_size = train_log_probs.shape[0]

    # Move the whole mask to host lists in one go; the segment search below is
    # pure Python and would otherwise re-slice the tensor for every sample.
    mask_rows = response_mask.tolist() if isinstance(response_mask, torch.Tensor) else response_mask

    # round_idx -> list of (train_vals, rollout_vals) slices, one per sample
    all_round_data: dict[int, list[tuple[torch.Tensor, torch.Tensor]]] = {}
    max_rounds = 0

    for b in range(batch_size):
        segments = _find_contiguous_segments(mask_rows[b])
        max_rounds = max(max_rounds, len(segments))

        for round_idx, (start, end) in enumerate(segments):
            all_round_data.setdefault(round_idx, []).append(
                (train_log_probs[b, start:end], rollout_log_probs[b, start:end])
            )

    if not all_round_data:
        return {"per_round/total_rounds": 0}

    metrics: dict = {"per_round/total_rounds": max_rounds}
    max_diff = -1.0  # sentinel below any real mean of absolute values
    max_diff_round = -1

    for round_idx in sorted(all_round_data):
        pairs = all_round_data[round_idx]
        train_all = torch.cat([t for t, _ in pairs])
        rollout_all = torch.cat([r for _, r in pairs])

        # Slices can come back empty if the log-prob tensors are shorter than
        # the mask; skip such rounds rather than report a NaN mean.
        if train_all.numel() == 0:
            continue

        mean_diff = torch.abs(train_all - rollout_all).mean().detach().item()

        metrics[f"per_round/round_{round_idx}_abs_diff_mean"] = mean_diff
        metrics[f"per_round/round_{round_idx}_token_count"] = train_all.numel()

        # Strict ">" keeps the earliest round on ties.
        if mean_diff > max_diff:
            max_diff = mean_diff
            max_diff_round = round_idx

    metrics["per_round/max_round_diff"] = max_diff_round
    if max_diff_round >= 0:
        metrics["per_round/max_diff_value"] = max_diff

    return metrics