diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 6d2ec56dc4..5e514ae4d7 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1977,6 +1977,7 @@ def async_grpo_train( "reward", "global_valid_seqs", "global_valid_toks", + "mean_prompt_length", }: metrics[k] = np.mean(v).item() else: diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index 62b92890ed..529e165afb 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -401,6 +401,11 @@ def visualize_per_worker_load(per_worker_token_counts: dict[int, int]) -> float: average_token_imbalance = visualize_per_worker_load(per_worker_token_counts) performance_metrics["average_token_imbalance"] = average_token_imbalance + if "mean_total_tokens_per_sample" in metrics: + print( + f" • Mean Total Tokens per Sample: {metrics['mean_total_tokens_per_sample']:.2f}" + ) + # ===================================================== # Throughputs # =====================================================