25 changes: 25 additions & 0 deletions nemo_rl/algorithms/grpo.py
@@ -39,6 +39,7 @@
)
from nemo_rl.algorithms.utils import (
calculate_baseline_and_std_per_prompt,
log_generation_metrics_to_wandb,
print_performance_metrics,
set_seed,
)
@@ -1475,6 +1476,18 @@ def grpo_train(
total_steps + 1,
name="train/token_mult_prob_error_plot_sample",
)
if master_config["policy"]["generation"].get("vllm_cfg", {}).get(
"enable_vllm_metrics_logger", False
) and master_config.get("logger", {}).get("wandb_enabled", False):
log_generation_metrics_to_wandb(
vllm_logger_metrics,
total_steps + 1,
master_config["policy"]["generation"]["vllm_cfg"][
"vllm_metrics_logger_interval"
],
logger,
)

print("\n📊 Training Results:")

print(f" • Loss: {metrics['loss']:.4f}")
@@ -2386,6 +2399,18 @@ def async_grpo_train(
metrics["buffer_size"] = buffer_size_current
metrics["avg_trajectory_age"] = avg_trajectory_age

if master_config["policy"]["generation"].get("vllm_cfg", {}).get(
"enable_vllm_metrics_logger", False
) and master_config.get("logger", {}).get("wandb_enabled", False):
log_generation_metrics_to_wandb(
vllm_logger_metrics,
step + 1,
master_config["policy"]["generation"]["vllm_cfg"][
"vllm_metrics_logger_interval"
],
logger,
)

print("\n📊 Training Results:")
print(f" • Loss: {metrics['loss']:.4f}")
print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}")
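For context, the new logging calls above only fire when both the vLLM metrics logger and wandb are enabled in the master config. Below is an illustrative sketch of the relevant keys, assuming the surrounding NeMo-RL config layout; all other fields are omitted.

# Illustrative excerpt of master_config showing only the keys the gate reads.
master_config = {
    "policy": {
        "generation": {
            "vllm_cfg": {
                "enable_vllm_metrics_logger": True,  # turn on the background metrics sampler
                "vllm_metrics_logger_interval": 0.5,  # seconds between samples (default 0.5)
            },
        },
    },
    "logger": {"wandb_enabled": True},  # plots are only emitted when wandb is enabled
}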
25 changes: 25 additions & 0 deletions nemo_rl/algorithms/utils.py
@@ -28,6 +28,7 @@

from nemo_rl.data.chat_templates import COMMON_CHAT_TEMPLATES
from nemo_rl.models.policy import TokenizerConfig
from nemo_rl.utils.logger import Logger


def calculate_kl(
@@ -744,3 +745,27 @@ def visualize_per_worker_timeline(
)

return performance_metrics


def log_generation_metrics_to_wandb(
vllm_logger_metrics: dict[str, dict[int, list[Any]]],
step: int,
timeline_interval: float,
logger: Logger,
) -> None:
"""Log vLLM metrics to wandb.

Args:
vllm_logger_metrics: Dictionary of vLLM logger metrics
step: Global step value
timeline_interval: Interval between timeline points (in seconds)
logger: Logger instance
"""
for vllm_metric in vllm_logger_metrics.keys():
logger.log_plot_per_worker_timeline_metrics(
vllm_logger_metrics[vllm_metric],
step=step,
prefix="vllm_metrics",
name=vllm_metric,
timeline_interval=timeline_interval,
)
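A hedged usage sketch for the helper above. The payload shape mirrors what VllmGeneration.get_vllm_logger_metrics() returns (metric name -> data-parallel worker index -> sampled values); the literal values here are made up for illustration, and `logger` is assumed to be an initialized nemo_rl.utils.logger.Logger.

vllm_logger_metrics = {
    "inflight_batch_sizes": {0: [4, 8, 8], 1: [4, 6, 7]},
    "kv_cache_usage_perc": {0: [0.12, 0.35, 0.41], 1: [0.10, 0.33, 0.44]},
}
log_generation_metrics_to_wandb(
    vllm_logger_metrics,
    step=10,
    timeline_interval=0.5,  # should match vllm_metrics_logger_interval
    logger=logger,
)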
10 changes: 9 additions & 1 deletion nemo_rl/models/generation/vllm/vllm_generation.py
@@ -838,9 +838,11 @@ def get_vllm_logger_metrics(self) -> dict[str, Any]:
dp_indices.append(dp_idx)

results = ray.get(futures)
vllm_logger_metrics: dict[str, dict[int, list[Any]]] = {
"inflight_batch_sizes": {}, # dp_idx -> list[int]
"num_pending_samples": {}, # dp_idx -> list[int]
"kv_cache_usage_perc": {}, # dp_idx -> list[float]
"generation_tokens": {}, # dp_idx -> list[int]
}

for dp_idx, stats in zip(dp_indices, results):
@@ -854,6 +856,12 @@ def get_vllm_logger_metrics(self) -> dict[str, Any]:
num_pending_samples = stats.get("num_pending_samples")
if num_pending_samples:
vllm_logger_metrics["num_pending_samples"][dp_idx] = num_pending_samples
kv_cache_usage_perc = stats.get("kv_cache_usage_perc")
if kv_cache_usage_perc:
vllm_logger_metrics["kv_cache_usage_perc"][dp_idx] = kv_cache_usage_perc
generation_tokens = stats.get("generation_tokens")
if generation_tokens:
vllm_logger_metrics["generation_tokens"][dp_idx] = generation_tokens

return vllm_logger_metrics

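A hypothetical driver-side snippet, assuming `policy_generation` is an already-constructed VllmGeneration instance; it only demonstrates the shape of the returned mapping.

metrics = policy_generation.get_vllm_logger_metrics()
for dp_idx, series in sorted(metrics["kv_cache_usage_perc"].items()):
    # each series holds one sample per logger tick for that data-parallel worker
    print(f"worker {dp_idx}: peak KV-cache usage {max(series):.3f}")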
27 changes: 19 additions & 8 deletions nemo_rl/models/generation/vllm/vllm_worker_async.py
@@ -167,7 +167,7 @@ def _start_vllm_metrics_logger(self) -> None:
Controlled by vllm_metrics_logger_interval (default: 0.5) in vllm_cfg.
Runs only on the model-owner actor.
"""
from vllm.v1.metrics.reader import Gauge, Counter, get_metrics_snapshot

assert self.cfg["vllm_cfg"].get("async_engine", False), (
"vLLM metrics logger is only supported with async engine enabled"
@@ -190,22 +190,29 @@

self.inflight_batch_sizes: list[int] = []
self.num_pending_samples: list[int] = []
self.kv_cache_usage_perc: list[float] = []
self.generation_tokens: list[int] = []

def _logger_loop():
# Delay a little to let engine settle
time.sleep(min(2.0, interval_s))
while True:
try:
for m in get_metrics_snapshot():
    with self._vllm_metrics_lock:
        if isinstance(m, Gauge):
            # Log the vllm inflight batch sizes
            if m.name == "vllm:num_requests_running":
                self.inflight_batch_sizes.append(int(m.value))
            # Log the vllm pending number of requests in the queue
            elif m.name == "vllm:num_requests_waiting":
                self.num_pending_samples.append(int(m.value))
            # Log the vllm kv cache usage
            elif m.name == "vllm:kv_cache_usage_perc":
                self.kv_cache_usage_perc.append(float(m.value))
        elif isinstance(m, Counter):
            if m.name == "vllm:generation_tokens":
                self.generation_tokens.append(int(m.value))
except Exception:
print(
"⚠️[vLLM Metric Logger] Exception in vLLM metrics logger",
@@ -232,6 +239,8 @@ def get_vllm_logger_metrics(self) -> dict[str, Any]:
metric = {
"inflight_batch_sizes": copy.deepcopy(self.inflight_batch_sizes),
"num_pending_samples": copy.deepcopy(self.num_pending_samples),
"kv_cache_usage_perc": copy.deepcopy(self.kv_cache_usage_perc),
"generation_tokens": copy.deepcopy(self.generation_tokens),
}
return metric

@@ -242,6 +251,8 @@ def clear_vllm_logger_metrics(self) -> None:
with self._vllm_metrics_lock:
self.inflight_batch_sizes = []
self.num_pending_samples = []
self.kv_cache_usage_perc = []
self.generation_tokens = []

async def post_init_async(self):
self.vllm_device_ids = await self.report_device_id_async()
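One detail of the logger loop above worth noting: vllm:generation_tokens is read from a Counter, which in Prometheus-style metrics is a cumulative running total, so each appended value is the total number of tokens generated so far rather than a per-interval count. A small post-processing sketch, assuming numpy; this helper is illustrative and not part of the PR.

import numpy as np

def tokens_per_second(generation_tokens: list[int], interval_s: float) -> np.ndarray:
    """Turn cumulative token totals sampled every interval_s seconds into per-second throughput."""
    totals = np.asarray(generation_tokens, dtype=np.float64)
    # successive differences give the tokens generated within each interval
    return np.diff(totals) / interval_s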
82 changes: 82 additions & 0 deletions nemo_rl/utils/logger.py
@@ -26,6 +26,7 @@
from typing import Any, Callable, Mapping, NotRequired, Optional, TypedDict

import mlflow
import numpy as np
import ray
import requests
import swanlab
@@ -935,6 +936,87 @@ def log_batched_dict_as_jsonl(

print(f"Logged data to {filepath}")

def log_plot_per_worker_timeline_metrics(
self,
metrics: dict[int, list[Any]],
step: int,
prefix: str,
name: str,
timeline_interval: float,
) -> None:
"""Log a plot of per-worker timeline metrics.

Args:
metrics: Dictionary of metrics to log, where the keys are worker IDs and the values are lists of metric values
    - metrics: dict[int, list[Any]] = {worker_id: [metric_value_1, metric_value_2, ...]}
    - each list is a time series; consecutive values are timeline_interval seconds apart
step: Global step value
prefix: Prefix used in the logged plot names
name: Name of the plot
timeline_interval: Interval between timeline points (in seconds)
"""
if not metrics:
print(
f"Skipping {name} per-worker timeline logging because no metrics were provided."
)
return

if timeline_interval <= 0:
raise ValueError(
f"timeline_interval must be positive; received {timeline_interval}"
)

# Plot the per-worker timeline metrics
x_series: list[list[float]] = []
y_series: list[list[float]] = []
series_labels: list[str] = []

if not any(metrics.values()):
print(
f"Skipping {name} per-worker timeline logging because all series were empty."
)
return

for worker_id in sorted(metrics.keys()):
metric_values = metrics[worker_id]
if not metric_values:
continue

x_series.append([i * timeline_interval for i in range(len(metric_values))])
y_series.append([float(v) for v in metric_values])
series_labels.append(f"worker_{worker_id}")

fig, ax = plt.subplots()
for label, xs, ys in zip(series_labels, x_series, y_series):
ax.plot(xs, ys, label=label)

ax.set_xlabel("Time (s)")
ax.set_ylabel(f"{name} (per worker)")
ax.set_title(name)
ax.grid(True, alpha=0.2)
fig.tight_layout()

for logger in self.loggers:
logger.log_plot(fig, step, f"{prefix}/per_worker_{name}")
plt.close(fig)

# Plot the average of the metrics, truncating every series to the shortest
# non-empty one so the mean only covers timestamps where all workers reported
min_length = min(len(v) for v in y_series)
avg_x = [i * timeline_interval for i in range(min_length)]
truncated_y_series = [v[:min_length] for v in y_series]

avg_y_series = np.mean(truncated_y_series, axis=0)

fig, ax = plt.subplots()
ax.plot(avg_x, avg_y_series, label="average")
ax.set_xlabel("Time (s)")
ax.set_ylabel(f"{name} (average)")
ax.set_title(name)
ax.grid(True, alpha=0.2)
fig.tight_layout()
for logger in self.loggers:
logger.log_plot(fig, step, f"{prefix}/average_{name}")
plt.close(fig)

def log_plot_token_mult_prob_error(
self, data: dict[str, Any], step: int, name: str
) -> None:
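A minimal sketch of calling the new logger method directly, assuming `logger` is an initialized nemo_rl.utils.logger.Logger with at least one backend configured; this is the same shape the unit tests below exercise.

logger.log_plot_per_worker_timeline_metrics(
    {0: [1, 2, 3], 1: [2, 3, 4]},  # worker_id -> sampled values
    step=1,
    prefix="vllm_metrics",
    name="inflight_batch_sizes",
    timeline_interval=0.5,  # seconds between samples
)
# Logs "vllm_metrics/per_worker_inflight_batch_sizes" and
# "vllm_metrics/average_inflight_batch_sizes" plots at step 1.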
93 changes: 92 additions & 1 deletion tests/unit/utils/test_logger.py
@@ -14,7 +14,7 @@

import shutil
import tempfile
from unittest.mock import patch
from unittest.mock import MagicMock, call, patch

import pytest
import torch
@@ -1741,6 +1741,97 @@ def test_log_hyperparams_with_mlflow(
mock_mlflow_instance.log_hyperparams.assert_called_once_with(params)
mock_swanlab_instance.log_hyperparams.assert_called_once_with(params)

def test_log_plot_per_worker_timeline_metrics_logs_expected_series(self):
"""Ensure per-worker and average plots are produced and logged."""
logger = Logger.__new__(Logger)
backend_logger = MagicMock()
logger.loggers = [backend_logger]

metrics = {
0: [1, 2, 3],
1: [2, 3, 4],
}

mock_fig_worker, mock_ax_worker = MagicMock(), MagicMock()
mock_fig_avg, mock_ax_avg = MagicMock(), MagicMock()

with (
patch(
"nemo_rl.utils.logger.plt.subplots",
side_effect=[
(mock_fig_worker, mock_ax_worker),
(mock_fig_avg, mock_ax_avg),
],
) as mock_subplots,
patch("nemo_rl.utils.logger.plt.close") as mock_close,
):
logger.log_plot_per_worker_timeline_metrics(
metrics,
step=1,
prefix="vllm",
name="kv_cache",
timeline_interval=0.5,
)

assert mock_subplots.call_count == 2
expected_x = [0.0, 0.5, 1.0]
mock_ax_worker.plot.assert_has_calls(
[
call(expected_x, [1.0, 2.0, 3.0], label="worker_0"),
call(expected_x, [2.0, 3.0, 4.0], label="worker_1"),
],
any_order=False,
)

avg_call = mock_ax_avg.plot.call_args_list[0]
assert avg_call.args[0] == expected_x
assert avg_call.args[1].tolist() == [1.5, 2.5, 3.5]
assert avg_call.kwargs["label"] == "average"

backend_logger.log_plot.assert_has_calls(
[
call(mock_fig_worker, 1, "vllm/per_worker_kv_cache"),
call(mock_fig_avg, 1, "vllm/average_kv_cache"),
],
any_order=False,
)
assert mock_close.call_args_list == [
call(mock_fig_worker),
call(mock_fig_avg),
]

def test_log_plot_per_worker_timeline_metrics_requires_positive_interval(self):
"""timeline_interval must be positive."""
logger = Logger.__new__(Logger)
logger.loggers = [MagicMock()]

with pytest.raises(ValueError):
logger.log_plot_per_worker_timeline_metrics(
metrics={0: [1, 2]},
step=1,
prefix="train",
name="pending",
timeline_interval=0.0,
)

def test_log_plot_per_worker_timeline_metrics_skips_when_no_data(self):
"""No plots should be produced when metrics are empty."""
logger = Logger.__new__(Logger)
backend_logger = MagicMock()
logger.loggers = [backend_logger]

with patch("nemo_rl.utils.logger.plt.subplots") as mock_subplots:
logger.log_plot_per_worker_timeline_metrics(
metrics={},
step=1,
prefix="train",
name="pending",
timeline_interval=1.0,
)

mock_subplots.assert_not_called()
backend_logger.log_plot.assert_not_called()


def test_print_message_log_samples(capsys):
"""Test that print_message_log_samples displays full content correctly."""