From cdad586399e4463cef0443ed4587ac8e8e641aa1 Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Wed, 3 Dec 2025 01:14:02 -0800 Subject: [PATCH 1/5] collect isl_osl metrics Signed-off-by: Youngeun Kwon --- nemo_rl/experience/rollouts.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py index b8b378542c..9f58d09024 100644 --- a/nemo_rl/experience/rollouts.py +++ b/nemo_rl/experience/rollouts.py @@ -21,6 +21,7 @@ import statistics from collections import defaultdict from dataclasses import dataclass +from functools import reduce from typing import Any, Optional import ray @@ -654,6 +655,8 @@ async def run_sample_multi_turn_rollout( # Track per-turn metrics turn_gen_tokens = [] + turn_input_tokens = [] + turn_total_tokens = [] # Track per-turn per-worker token accounting if available per_worker_token_counts = {} # worker_idx -> token_count @@ -685,6 +688,8 @@ async def run_sample_multi_turn_rollout( assistant_token_count += gen_token_count token_count += gen_token_count turn_gen_tokens.append(gen_token_count) + turn_input_tokens.append(int(input_lengths)) + turn_total_tokens.append(int(input_lengths) + gen_token_count) # Per-worker load accounting if "gen_leader_worker_idx" in gen_metrics: worker_idx = int(gen_metrics["gen_leader_worker_idx"]) @@ -770,6 +775,8 @@ async def run_sample_multi_turn_rollout( "max_turns_reached": max_turns_reached, "total_reward": total_reward, "turn_gen_tokens": turn_gen_tokens, + "turn_input_tokens": turn_input_tokens, + "turn_total_tokens": turn_total_tokens, # Pass-through per-worker per-turn accounting for aggregation at batch level "per_worker_token_counts": per_worker_token_counts, } @@ -930,6 +937,17 @@ async def run_single_sample_with_error_handling(i, sample_state): per_worker_token_counts[k] = per_worker_token_counts.get(k, 0) + v rollout_metrics["per_worker_token_counts"] = per_worker_token_counts + # Collect ISL, OSL, and ISL+OSL metrics for all samples + rollout_metrics["gen_tokens_lengths"] = reduce( + lambda x, y: x + y, [m["turn_gen_tokens"] for m in all_sample_metrics] + ) + rollout_metrics["input_tokens_lengths"] = reduce( + lambda x, y: x + y, [m["turn_input_tokens"] for m in all_sample_metrics] + ) + rollout_metrics["total_tokens_lengths"] = reduce( + lambda x, y: x + y, [m["turn_total_tokens"] for m in all_sample_metrics] + ) + return final_batch, rollout_metrics return asyncio.run(_async_rollout_implementation()) From 49ebf4fdfe55254421a25e7226905c70cdf20bfc Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Wed, 3 Dec 2025 01:17:02 -0800 Subject: [PATCH 2/5] rename vllm_logger_metrics Signed-off-by: Youngeun Kwon --- nemo_rl/algorithms/utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index da1572b902..17c69e479a 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -748,24 +748,24 @@ def visualize_per_worker_timeline( def log_generation_metrics_to_wandb( - vllm_logger_metrics: dict[str, dict[int, list[Any]]], + generation_logger_metrics: dict[str, dict[int, list[Any]]], step: int, timeline_interval: float, logger: Logger, ) -> None: - """Log vLLM metrics to wandb. + """Log generation metrics to wandb. Args: - vllm_logger_metrics: Dictionary of vLLM logger metrics + generation_logger_metrics: Dictionary of generation logger metrics step: Global step value timeline_interval: Interval between timeline points (in seconds) logger: Logger instance """ - for vllm_metric in vllm_logger_metrics.keys(): + for generation_metric in generation_logger_metrics.keys(): logger.log_plot_per_worker_timeline_metrics( - vllm_logger_metrics[vllm_metric], + generation_logger_metrics[generation_metric], step=step, - prefix="vllm_metrics", - name=vllm_metric, + prefix="generation_metrics", + name=generation_metric, timeline_interval=timeline_interval, ) From 1e67087990f53d9c075ea9434732b8175b4e40b1 Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Wed, 3 Dec 2025 01:49:48 -0800 Subject: [PATCH 3/5] log histogram Signed-off-by: Youngeun Kwon --- nemo_rl/algorithms/grpo.py | 35 +++++++++++++++++++++++++++++++++++ nemo_rl/algorithms/utils.py | 22 ++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 23fd2a12f0..71872ff940 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -40,6 +40,7 @@ from nemo_rl.algorithms.utils import ( calculate_baseline_and_std_per_prompt, log_generation_metrics_to_wandb, + log_histogram_metrics_to_wandb, print_performance_metrics, set_seed, ) @@ -1566,6 +1567,23 @@ def grpo_train( logger, ) + # Plot ISL/OSL/ISL+OSL histograms to wandb + try: + for hist_metrics in [ + "gen_tokens_lengths", + "input_tokens_lengths", + "total_tokens_lengths", + ]: + log_histogram_metrics_to_wandb( + f"generation_metrics/{hist_metrics}", + metrics[hist_metrics], + total_steps + 1, + logger, + ) + except Exception as e: + print(f"āŒ Error plotting histograms to wandb: {e}") + pass + print("\nšŸ“Š Training Results:") print(f" • Loss: {metrics['loss']:.4f}") @@ -2489,6 +2507,23 @@ def async_grpo_train( logger, ) + # Plot ISL/OSL/ISL+OSL histograms to wandb + try: + for hist_metrics in [ + "gen_tokens_lengths", + "input_tokens_lengths", + "total_tokens_lengths", + ]: + log_histogram_metrics_to_wandb( + f"generation_metrics/{hist_metrics}", + metrics[hist_metrics], + step + 1, + logger, + ) + except Exception as e: + print(f"āŒ Error plotting histograms to wandb: {e}") + pass + print("\nšŸ“Š Training Results:") print(f" • Loss: {metrics['loss']:.4f}") print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}") diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index 17c69e479a..b0316e303f 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -769,3 +769,25 @@ def log_generation_metrics_to_wandb( name=generation_metric, timeline_interval=timeline_interval, ) + + +def log_histogram_metrics_to_wandb( + metric_name: str, + metric_values: list[Any], + step: int, + logger: Logger, +) -> None: + """Log histogram metrics to wandb. + + Args: + metric_name: Name of the metric + metric_values: List of metric values + step: Global step value + logger: Logger instance + """ + if logger.wandb_logger: + import wandb # pyright: ignore[reportMissingImports] + + logger.wandb_logger.run.log( + {metric_name: wandb.Histogram(metric_values)}, step=step + ) From abcbd883816ceb1108f375a103a48e396aa3cae5 Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Thu, 4 Dec 2025 11:18:20 -0800 Subject: [PATCH 4/5] adjust comments Signed-off-by: Youngeun Kwon --- nemo_rl/algorithms/grpo.py | 55 +++++++++++++++------------------- nemo_rl/algorithms/utils.py | 22 -------------- nemo_rl/experience/rollouts.py | 19 ++++++------ nemo_rl/utils/logger.py | 26 ++++++++++++++++ 4 files changed, 59 insertions(+), 63 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 71872ff940..81d64642ce 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -40,7 +40,6 @@ from nemo_rl.algorithms.utils import ( calculate_baseline_and_std_per_prompt, log_generation_metrics_to_wandb, - log_histogram_metrics_to_wandb, print_performance_metrics, set_seed, ) @@ -1568,21 +1567,18 @@ def grpo_train( ) # Plot ISL/OSL/ISL+OSL histograms to wandb - try: - for hist_metrics in [ - "gen_tokens_lengths", - "input_tokens_lengths", - "total_tokens_lengths", - ]: - log_histogram_metrics_to_wandb( - f"generation_metrics/{hist_metrics}", - metrics[hist_metrics], - total_steps + 1, - logger, - ) - except Exception as e: - print(f"āŒ Error plotting histograms to wandb: {e}") - pass + if ( + master_config["policy"]["generation"] + .get("vllm_cfg", {}) + .get("async_engine", False) + ): + for metric_name in metrics.keys(): + if metric_name.startswith("histogram/"): + logger.log_histogram( + metrics[metric_name], + total_steps + 1, + f"generation_metrics/{metric_name}", + ) print("\nšŸ“Š Training Results:") @@ -2508,21 +2504,18 @@ def async_grpo_train( ) # Plot ISL/OSL/ISL+OSL histograms to wandb - try: - for hist_metrics in [ - "gen_tokens_lengths", - "input_tokens_lengths", - "total_tokens_lengths", - ]: - log_histogram_metrics_to_wandb( - f"generation_metrics/{hist_metrics}", - metrics[hist_metrics], - step + 1, - logger, - ) - except Exception as e: - print(f"āŒ Error plotting histograms to wandb: {e}") - pass + if ( + master_config["policy"]["generation"] + .get("vllm_cfg", {}) + .get("async_engine", False) + ): + for metric_name in metrics.keys(): + if metric_name.startswith("histogram/"): + logger.log_histogram( + metrics[metric_name], + step + 1, + f"generation_metrics/{metric_name}", + ) print("\nšŸ“Š Training Results:") print(f" • Loss: {metrics['loss']:.4f}") diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index b0316e303f..17c69e479a 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -769,25 +769,3 @@ def log_generation_metrics_to_wandb( name=generation_metric, timeline_interval=timeline_interval, ) - - -def log_histogram_metrics_to_wandb( - metric_name: str, - metric_values: list[Any], - step: int, - logger: Logger, -) -> None: - """Log histogram metrics to wandb. - - Args: - metric_name: Name of the metric - metric_values: List of metric values - step: Global step value - logger: Logger instance - """ - if logger.wandb_logger: - import wandb # pyright: ignore[reportMissingImports] - - logger.wandb_logger.run.log( - {metric_name: wandb.Histogram(metric_values)}, step=step - ) diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py index 9f58d09024..49666f92ef 100644 --- a/nemo_rl/experience/rollouts.py +++ b/nemo_rl/experience/rollouts.py @@ -21,7 +21,6 @@ import statistics from collections import defaultdict from dataclasses import dataclass -from functools import reduce from typing import Any, Optional import ray @@ -938,15 +937,15 @@ async def run_single_sample_with_error_handling(i, sample_state): rollout_metrics["per_worker_token_counts"] = per_worker_token_counts # Collect ISL, OSL, and ISL+OSL metrics for all samples - rollout_metrics["gen_tokens_lengths"] = reduce( - lambda x, y: x + y, [m["turn_gen_tokens"] for m in all_sample_metrics] - ) - rollout_metrics["input_tokens_lengths"] = reduce( - lambda x, y: x + y, [m["turn_input_tokens"] for m in all_sample_metrics] - ) - rollout_metrics["total_tokens_lengths"] = reduce( - lambda x, y: x + y, [m["turn_total_tokens"] for m in all_sample_metrics] - ) + rollout_metrics["histogram/gen_tokens_length"] = [ + t for m in all_sample_metrics for t in m["turn_gen_tokens"] + ] + rollout_metrics["histogram/input_tokens_length"] = [ + t for m in all_sample_metrics for t in m["turn_input_tokens"] + ] + rollout_metrics["histogram/total_tokens_length"] = [ + t for m in all_sample_metrics for t in m["turn_total_tokens"] + ] return final_batch, rollout_metrics diff --git a/nemo_rl/utils/logger.py b/nemo_rl/utils/logger.py index 736a6abc7b..6580e25f42 100644 --- a/nemo_rl/utils/logger.py +++ b/nemo_rl/utils/logger.py @@ -108,6 +108,11 @@ def log_hyperparams(self, params: Mapping[str, Any]) -> None: """Log dictionary of hyperparameters.""" pass + @abstractmethod + def log_histogram(self, histogram: list[Any], step: int, name: str) -> None: + """Log histogram metrics.""" + pass + class TensorboardLogger(LoggerInterface): """Tensorboard logger backend.""" @@ -350,6 +355,16 @@ def log_plot(self, figure: plt.Figure, step: int, name: str) -> None: """ self.run.log({name: figure}, step=step) + def log_histogram(self, histogram: list[Any], step: int, name: str) -> None: + """Log histogram metrics to wandb. + + Args: + histogram: List of histogram values + step: Global step value + name: Name of the metric + """ + self.run.log({name: wandb.Histogram(histogram)}, step=step) + class SwanlabLogger(LoggerInterface): """SwanLab logger backend.""" @@ -1017,6 +1032,17 @@ def log_plot_per_worker_timeline_metrics( logger.log_plot(fig, step, f"{prefix}/average_{name}") plt.close(fig) + def log_histogram(self, histogram: list[Any], step: int, name: str) -> None: + """Log histogram metrics to all backends if available. + + Args: + histogram: List of histogram values + step: Global step value + name: Name of the metric + """ + for logger in self.loggers: + logger.log_histogram(histogram, step, name) + def log_plot_token_mult_prob_error( self, data: dict[str, Any], step: int, name: str ) -> None: From 6b380cae91017b9ac8dd3dd0a9b7901f548807b7 Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Thu, 4 Dec 2025 13:49:55 -0800 Subject: [PATCH 5/5] fix ci Signed-off-by: Youngeun Kwon --- nemo_rl/utils/logger.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/nemo_rl/utils/logger.py b/nemo_rl/utils/logger.py index 6580e25f42..5586312b9d 100644 --- a/nemo_rl/utils/logger.py +++ b/nemo_rl/utils/logger.py @@ -158,6 +158,10 @@ def log_metrics( print(f"Warning: Failed to log metric '{name}' to TensorBoard: {e}") continue + def log_histogram(self, histogram: list[Any], step: int, name: str) -> None: + """Log histogram metrics to Tensorboard.""" + return + def log_hyperparams(self, params: Mapping[str, Any]) -> None: """Log hyperparameters to Tensorboard. @@ -434,6 +438,10 @@ def log_plot(self, figure: plt.Figure, step: int, name: str) -> None: """ self.run.log({name: figure}, step=step) + def log_histogram(self, histogram: list[Any], step: int, name: str) -> None: + """Log histogram metrics to swanlab.""" + return + class GpuMetricSnapshot(TypedDict): step: int @@ -808,6 +816,10 @@ def log_plot(self, figure: plt.Figure, step: int, name: str) -> None: figure.savefig(tmp_file.name, format="png", bbox_inches="tight") mlflow.log_artifact(tmp_file.name, f"plots/{name}") + def log_histogram(self, histogram: list[Any], step: int, name: str) -> None: + """Log histogram metrics to MLflow.""" + return + def __del__(self) -> None: """Clean up resources when the logger is destroyed.""" try: