diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 75766a3fdc..473740df1f 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -39,6 +39,7 @@ ) from nemo_rl.algorithms.utils import ( calculate_baseline_and_std_per_prompt, + log_generation_metrics_to_wandb, print_performance_metrics, set_seed, ) @@ -1475,6 +1476,18 @@ def grpo_train( total_steps + 1, name="train/token_mult_prob_error_plot_sample", ) + if master_config["policy"]["generation"].get("vllm_cfg", {}).get( + "enable_vllm_metrics_logger", False + ) and master_config.get("logger", {}).get("wandb_enabled", False): + log_generation_metrics_to_wandb( + vllm_logger_metrics, + total_steps + 1, + master_config["policy"]["generation"]["vllm_cfg"][ + "vllm_metrics_logger_interval" + ], + logger, + ) + print("\nšŸ“Š Training Results:") print(f" • Loss: {metrics['loss']:.4f}") @@ -2386,6 +2399,18 @@ def async_grpo_train( metrics["buffer_size"] = buffer_size_current metrics["avg_trajectory_age"] = avg_trajectory_age + if master_config["policy"]["generation"].get("vllm_cfg", {}).get( + "enable_vllm_metrics_logger", False + ) and master_config.get("logger", {}).get("wandb_enabled", False): + log_generation_metrics_to_wandb( + vllm_logger_metrics, + step + 1, + master_config["policy"]["generation"]["vllm_cfg"][ + "vllm_metrics_logger_interval" + ], + logger, + ) + print("\nšŸ“Š Training Results:") print(f" • Loss: {metrics['loss']:.4f}") print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}") diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index 534173e523..da1572b902 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -28,6 +28,7 @@ from nemo_rl.data.chat_templates import COMMON_CHAT_TEMPLATES from nemo_rl.models.policy import TokenizerConfig +from nemo_rl.utils.logger import Logger def calculate_kl( @@ -744,3 +745,27 @@ def visualize_per_worker_timeline( ) return performance_metrics + + +def 
log_generation_metrics_to_wandb( + vllm_logger_metrics: dict[str, dict[int, list[Any]]], + step: int, + timeline_interval: float, + logger: Logger, +) -> None: + """Log vLLM metrics to wandb. + + Args: + vllm_logger_metrics: Dictionary of vLLM logger metrics + step: Global step value + timeline_interval: Interval between timeline points (in seconds) + logger: Logger instance + """ + for vllm_metric in vllm_logger_metrics.keys(): + logger.log_plot_per_worker_timeline_metrics( + vllm_logger_metrics[vllm_metric], + step=step, + prefix="vllm_metrics", + name=vllm_metric, + timeline_interval=timeline_interval, + ) diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 87e480a31e..1525cbaee3 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -838,9 +838,11 @@ def get_vllm_logger_metrics(self) -> dict[str, Any]: dp_indices.append(dp_idx) results = ray.get(futures) - vllm_logger_metrics: dict[str, dict[int, list[int]]] = { + vllm_logger_metrics: dict[str, dict[int, list[Any]]] = { "inflight_batch_sizes": {}, # dp_idx -> list[int] "num_pending_samples": {}, # dp_idx -> list[int] + "kv_cache_usage_perc": {}, # dp_idx -> list[float] + "generation_tokens": {}, # dp_idx -> list[int] } for dp_idx, stats in zip(dp_indices, results): @@ -854,6 +856,12 @@ def get_vllm_logger_metrics(self) -> dict[str, Any]: num_pending_samples = stats.get("num_pending_samples") if num_pending_samples: vllm_logger_metrics["num_pending_samples"][dp_idx] = num_pending_samples + kv_cache_usage_perc = stats.get("kv_cache_usage_perc") + if kv_cache_usage_perc: + vllm_logger_metrics["kv_cache_usage_perc"][dp_idx] = kv_cache_usage_perc + generation_tokens = stats.get("generation_tokens") + if generation_tokens: + vllm_logger_metrics["generation_tokens"][dp_idx] = generation_tokens return vllm_logger_metrics diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py 
b/nemo_rl/models/generation/vllm/vllm_worker_async.py index 93910b72fb..4887984e8a 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -167,7 +167,7 @@ def _start_vllm_metrics_logger(self) -> None: Controlled by vllm_metrics_logger_interval (default: 0.5) in vllm_cfg. Runs only on the model-owner actor. """ - from vllm.v1.metrics.reader import Gauge, get_metrics_snapshot + from vllm.v1.metrics.reader import Gauge, Counter, get_metrics_snapshot assert self.cfg["vllm_cfg"].get("async_engine", False), ( "vLLM metrics logger is only supported with async engine enabled" @@ -190,6 +190,8 @@ def _start_vllm_metrics_logger(self) -> None: self.inflight_batch_sizes: list[int] = [] self.num_pending_samples: list[int] = [] + self.kv_cache_usage_perc: list[float] = [] + self.generation_tokens: list[int] = [] def _logger_loop(): # Delay a little to let engine settle @@ -197,15 +199,20 @@ def _logger_loop(): while True: try: for m in get_metrics_snapshot(): - if isinstance(m, Gauge): - # Log the vllm inflight batch sizes - if m.name == "vllm:num_requests_running": - with self._vllm_metrics_lock: + with self._vllm_metrics_lock: + if isinstance(m, Gauge): + # Log the vllm inflight batch sizes + if m.name == "vllm:num_requests_running": self.inflight_batch_sizes.append(int(m.value)) - # Log the vllm pending number of requests in the queue - elif m.name == "vllm:num_requests_waiting": - with self._vllm_metrics_lock: + # Log the vllm pending number of requests in the queue + elif m.name == "vllm:num_requests_waiting": self.num_pending_samples.append(int(m.value)) + # Log the vllm kv cache usage + elif m.name == "vllm:kv_cache_usage_perc": + self.kv_cache_usage_perc.append(float(m.value)) + elif isinstance(m, Counter): + if m.name == "vllm:generation_tokens": + self.generation_tokens.append(int(m.value)) except Exception: print( "āš ļø[vLLM Metric Logger] Exception in vLLM metrics logger", @@ -232,6 +239,8 @@ 
def get_vllm_logger_metrics(self) -> dict[str, Any]: metric = { "inflight_batch_sizes": copy.deepcopy(self.inflight_batch_sizes), "num_pending_samples": copy.deepcopy(self.num_pending_samples), + "kv_cache_usage_perc": copy.deepcopy(self.kv_cache_usage_perc), + "generation_tokens": copy.deepcopy(self.generation_tokens), } return metric @@ -242,6 +251,8 @@ def clear_vllm_logger_metrics(self) -> None: with self._vllm_metrics_lock: self.inflight_batch_sizes = [] self.num_pending_samples = [] + self.kv_cache_usage_perc = [] + self.generation_tokens = [] async def post_init_async(self): self.vllm_device_ids = await self.report_device_id_async() diff --git a/nemo_rl/utils/logger.py b/nemo_rl/utils/logger.py index eef08b2e70..736a6abc7b 100644 --- a/nemo_rl/utils/logger.py +++ b/nemo_rl/utils/logger.py @@ -26,6 +26,7 @@ from typing import Any, Callable, Mapping, NotRequired, Optional, TypedDict import mlflow +import numpy as np import ray import requests import swanlab @@ -935,6 +936,87 @@ def log_batched_dict_as_jsonl( print(f"Logged data to {filepath}") + def log_plot_per_worker_timeline_metrics( + self, + metrics: dict[int, list[Any]], + step: int, + prefix: str, + name: str, + timeline_interval: float, + ) -> None: + """Log a plot of per-worker timeline metrics. + + Args: + metrics: Dictionary of metrics to log, where the keys are the worker IDs and the values are the lists of metric values + - metrics: dict[str, list[Any]] = {worker_id: [metric_value_1, metric_value_2, ...]} + - metric values are time series values over time, the timing gap between the values is the timeline_interval + step: Global step value + name: Name of the plot + timeline_interval: Interval between timeline points (in seconds) + """ + if not metrics: + print( + f"Skipping {name} per-worker timeline logging because no metrics were provided." 
+            ) +            return + +        if timeline_interval <= 0: +            raise ValueError( +                f"timeline_interval must be positive; received {timeline_interval}" +            ) + +        # Plot the per-worker timeline metrics +        x_series: list[list[float]] = [] +        y_series: list[list[float]] = [] +        series_labels: list[str] = [] + +        if not any(metrics.values()): +            print( +                f"Skipping {name} per-worker timeline logging because all series were empty." +            ) +            return + +        for worker_id in sorted(metrics.keys()): +            metric_values = metrics[worker_id] +            if not metric_values: +                continue + +            x_series.append([i * timeline_interval for i in range(len(metric_values))]) +            y_series.append([float(v) for v in metric_values]) +            series_labels.append(f"worker_{worker_id}") + +        fig, ax = plt.subplots() +        for label, xs, ys in zip(series_labels, x_series, y_series): +            ax.plot(xs, ys, label=label) + +        ax.set_xlabel("Time (s)") +        ax.set_ylabel(f"{name} (per worker)") +        ax.set_title(name) +        ax.grid(True, alpha=0.2) +        fig.tight_layout() + +        for logger in self.loggers: +            logger.log_plot(fig, step, f"{prefix}/per_worker_{name}") +        plt.close(fig) + +        # Plot the average of the metrics +        min_length = min(len(v) for v in y_series) +        x_series = [i * timeline_interval for i in range(min_length)] +        truncated_y_series = [v[:min_length] for v in y_series] + +        avg_y_series = np.mean(truncated_y_series, axis=0) + +        fig, ax = plt.subplots() +        ax.plot(x_series, avg_y_series, label="average") +        ax.set_xlabel("Time (s)") +        ax.set_ylabel(f"{name} (average)") +        ax.set_title(name) +        ax.grid(True, alpha=0.2) +        fig.tight_layout() +        for logger in self.loggers: +            logger.log_plot(fig, step, f"{prefix}/average_{name}") +        plt.close(fig) +     def log_plot_token_mult_prob_error( self, data: dict[str, Any], step: int, name: str ) -> None: diff --git a/tests/unit/utils/test_logger.py b/tests/unit/utils/test_logger.py index c9771cea0c..2fc3d3c240 100644 --- a/tests/unit/utils/test_logger.py +++ b/tests/unit/utils/test_logger.py @@ -14,7 +14,7 @@ import shutil import
tempfile -from unittest.mock import patch +from unittest.mock import MagicMock, call, patch import pytest import torch @@ -1741,6 +1741,97 @@ def test_log_hyperparams_with_mlflow( mock_mlflow_instance.log_hyperparams.assert_called_once_with(params) mock_swanlab_instance.log_hyperparams.assert_called_once_with(params) + def test_log_plot_per_worker_timeline_metrics_logs_expected_series(self): + """Ensure per-worker and average plots are produced and logged.""" + logger = Logger.__new__(Logger) + backend_logger = MagicMock() + logger.loggers = [backend_logger] + + metrics = { + 0: [1, 2, 3], + 1: [2, 3, 4], + } + + mock_fig_worker, mock_ax_worker = MagicMock(), MagicMock() + mock_fig_avg, mock_ax_avg = MagicMock(), MagicMock() + + with ( + patch( + "nemo_rl.utils.logger.plt.subplots", + side_effect=[ + (mock_fig_worker, mock_ax_worker), + (mock_fig_avg, mock_ax_avg), + ], + ) as mock_subplots, + patch("nemo_rl.utils.logger.plt.close") as mock_close, + ): + logger.log_plot_per_worker_timeline_metrics( + metrics, + step=1, + prefix="vllm", + name="kv_cache", + timeline_interval=0.5, + ) + + assert mock_subplots.call_count == 2 + expected_x = [0.0, 0.5, 1.0] + mock_ax_worker.plot.assert_has_calls( + [ + call(expected_x, [1.0, 2.0, 3.0], label="worker_0"), + call(expected_x, [2.0, 3.0, 4.0], label="worker_1"), + ], + any_order=False, + ) + + avg_call = mock_ax_avg.plot.call_args_list[0] + assert avg_call.args[0] == expected_x + assert avg_call.args[1].tolist() == [1.5, 2.5, 3.5] + assert avg_call.kwargs["label"] == "average" + + backend_logger.log_plot.assert_has_calls( + [ + call(mock_fig_worker, 1, "vllm/per_worker_kv_cache"), + call(mock_fig_avg, 1, "vllm/average_kv_cache"), + ], + any_order=False, + ) + assert mock_close.call_args_list == [ + call(mock_fig_worker), + call(mock_fig_avg), + ] + + def test_log_plot_per_worker_timeline_metrics_requires_positive_interval(self): + """timeline_interval must be positive.""" + logger = Logger.__new__(Logger) + logger.loggers 
= [MagicMock()] + + with pytest.raises(ValueError): + logger.log_plot_per_worker_timeline_metrics( + metrics={0: [1, 2]}, + step=1, + prefix="train", + name="pending", + timeline_interval=0.0, + ) + + def test_log_plot_per_worker_timeline_metrics_skips_when_no_data(self): + """No plots should be produced when metrics are empty.""" + logger = Logger.__new__(Logger) + backend_logger = MagicMock() + logger.loggers = [backend_logger] + + with patch("nemo_rl.utils.logger.plt.subplots") as mock_subplots: + logger.log_plot_per_worker_timeline_metrics( + metrics={}, + step=1, + prefix="train", + name="pending", + timeline_interval=1.0, + ) + + mock_subplots.assert_not_called() + backend_logger.log_plot.assert_not_called() + def test_print_message_log_samples(capsys): """Test that print_message_log_samples displays full content correctly."""