diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index c39dc98535..0411f7500c 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1777,9 +1777,6 @@ def async_grpo_train( train_results, metrics, timing_metrics, master_config ) - if "per_worker_token_counts" in metrics: - del metrics["per_worker_token_counts"] - logger.log_metrics(performance_metrics, step + 1, prefix="performance") logger.log_metrics(metrics, step + 1, prefix="train") logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index e15d731efb..62b92890ed 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -553,6 +553,15 @@ def visualize_per_worker_load(per_worker_token_counts: dict[int, int]) -> float: total_tflops / theoretical_tflops ) + # ===================================================== + # Clean up metrics + # ===================================================== + + # Clean up metrics to avoid wandb logging errors + # Dict structures cannot be logged to wandb + if "per_worker_token_counts" in metrics: + del metrics["per_worker_token_counts"] + # ===================================================== # Logging # =====================================================