diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index d69b45fdfb..65e0b92591 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1779,9 +1779,6 @@ def async_grpo_train( train_results, metrics, timing_metrics, master_config ) - if "per_worker_token_counts" in metrics: - del metrics["per_worker_token_counts"] - logger.log_metrics(performance_metrics, step + 1, prefix="performance") logger.log_metrics(metrics, step + 1, prefix="train") logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index e15d731efb..62b92890ed 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -553,6 +553,15 @@ def visualize_per_worker_load(per_worker_token_counts: dict[int, int]) -> float: total_tflops / theoretical_tflops ) + # ===================================================== + # Clean up metrics + # ===================================================== + + # Clean up metrics to avoid wandb logging errors + # Dict structures cannot be logged to wandb + if "per_worker_token_counts" in metrics: + del metrics["per_worker_token_counts"] + # ===================================================== # Logging # =====================================================