Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1566,6 +1566,20 @@ def grpo_train(
logger,
)

# Plot ISL/OSL/ISL+OSL histograms to wandb
if (
master_config["policy"]["generation"]
.get("vllm_cfg", {})
.get("async_engine", False)
):
for metric_name in metrics.keys():
if metric_name.startswith("histogram/"):
logger.log_histogram(
metrics[metric_name],
total_steps + 1,
f"generation_metrics/{metric_name}",
)

print("\n📊 Training Results:")

print(f" • Loss: {metrics['loss']:.4f}")
Expand Down Expand Up @@ -2489,6 +2503,20 @@ def async_grpo_train(
logger,
)

# Plot ISL/OSL/ISL+OSL histograms to wandb
if (
master_config["policy"]["generation"]
.get("vllm_cfg", {})
.get("async_engine", False)
):
for metric_name in metrics.keys():
if metric_name.startswith("histogram/"):
logger.log_histogram(
metrics[metric_name],
step + 1,
f"generation_metrics/{metric_name}",
)

print("\n📊 Training Results:")
print(f" • Loss: {metrics['loss']:.4f}")
print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}")
Expand Down
14 changes: 7 additions & 7 deletions nemo_rl/algorithms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,24 +748,24 @@ def visualize_per_worker_timeline(


def log_generation_metrics_to_wandb(
    generation_logger_metrics: dict[str, dict[int, list[Any]]],
    step: int,
    timeline_interval: float,
    logger: Logger,
) -> None:
    """Log generation metrics to wandb.

    Each entry of ``generation_logger_metrics`` is forwarded to the logger as a
    per-worker timeline plot under the ``generation_metrics`` prefix.

    Args:
        generation_logger_metrics: Mapping of metric name to per-worker data
            (worker index -> list of timeline values).
        step: Global step value.
        timeline_interval: Interval between timeline points (in seconds).
        logger: Logger instance.
    """
    # Iterate items() directly instead of keys() + per-key re-lookup (ruff PERF102).
    for metric_name, worker_data in generation_logger_metrics.items():
        logger.log_plot_per_worker_timeline_metrics(
            worker_data,
            step=step,
            prefix="generation_metrics",
            name=metric_name,
            timeline_interval=timeline_interval,
        )
17 changes: 17 additions & 0 deletions nemo_rl/experience/rollouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,8 @@ async def run_sample_multi_turn_rollout(

# Track per-turn metrics
turn_gen_tokens = []
turn_input_tokens = []
turn_total_tokens = []
# Track per-turn per-worker token accounting if available
per_worker_token_counts = {} # worker_idx -> token_count

Expand Down Expand Up @@ -685,6 +687,8 @@ async def run_sample_multi_turn_rollout(
assistant_token_count += gen_token_count
token_count += gen_token_count
turn_gen_tokens.append(gen_token_count)
turn_input_tokens.append(int(input_lengths))
turn_total_tokens.append(int(input_lengths) + gen_token_count)
# Per-worker load accounting
if "gen_leader_worker_idx" in gen_metrics:
worker_idx = int(gen_metrics["gen_leader_worker_idx"])
Expand Down Expand Up @@ -770,6 +774,8 @@ async def run_sample_multi_turn_rollout(
"max_turns_reached": max_turns_reached,
"total_reward": total_reward,
"turn_gen_tokens": turn_gen_tokens,
"turn_input_tokens": turn_input_tokens,
"turn_total_tokens": turn_total_tokens,
# Pass-through per-worker per-turn accounting for aggregation at batch level
"per_worker_token_counts": per_worker_token_counts,
}
Expand Down Expand Up @@ -930,6 +936,17 @@ async def run_single_sample_with_error_handling(i, sample_state):
per_worker_token_counts[k] = per_worker_token_counts.get(k, 0) + v
rollout_metrics["per_worker_token_counts"] = per_worker_token_counts

# Collect ISL, OSL, and ISL+OSL metrics for all samples
rollout_metrics["histogram/gen_tokens_length"] = [
t for m in all_sample_metrics for t in m["turn_gen_tokens"]
]
rollout_metrics["histogram/input_tokens_length"] = [
t for m in all_sample_metrics for t in m["turn_input_tokens"]
]
rollout_metrics["histogram/total_tokens_length"] = [
t for m in all_sample_metrics for t in m["turn_total_tokens"]
]

return final_batch, rollout_metrics

return asyncio.run(_async_rollout_implementation())
Expand Down
38 changes: 38 additions & 0 deletions nemo_rl/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ def log_hyperparams(self, params: Mapping[str, Any]) -> None:
"""Log dictionary of hyperparameters."""
pass

@abstractmethod
def log_histogram(self, histogram: list[Any], step: int, name: str) -> None:
    """Log histogram metrics.

    Backends that cannot render histograms may implement this as a no-op.

    Args:
        histogram: List of raw values to be aggregated into a histogram
        step: Global step value
        name: Name of the metric
    """
    pass


class TensorboardLogger(LoggerInterface):
"""Tensorboard logger backend."""
Expand Down Expand Up @@ -153,6 +158,10 @@ def log_metrics(
print(f"Warning: Failed to log metric '{name}' to TensorBoard: {e}")
continue

def log_histogram(self, histogram: list[Any], step: int, name: str) -> None:
    """Log histogram metrics to Tensorboard.

    Currently a no-op for the TensorBoard backend; the arguments are ignored.

    Args:
        histogram: List of histogram values (ignored)
        step: Global step value (ignored)
        name: Name of the metric (ignored)
    """
    pass

def log_hyperparams(self, params: Mapping[str, Any]) -> None:
"""Log hyperparameters to Tensorboard.

Expand Down Expand Up @@ -350,6 +359,16 @@ def log_plot(self, figure: plt.Figure, step: int, name: str) -> None:
"""
self.run.log({name: figure}, step=step)

def log_histogram(self, histogram: list[Any], step: int, name: str) -> None:
    """Log histogram metrics to wandb.

    Args:
        histogram: List of histogram values
        step: Global step value
        name: Name of the metric
    """
    # Wrap the raw values in a wandb Histogram so the UI renders a distribution.
    payload = {name: wandb.Histogram(histogram)}
    self.run.log(payload, step=step)


class SwanlabLogger(LoggerInterface):
"""SwanLab logger backend."""
Expand Down Expand Up @@ -419,6 +438,10 @@ def log_plot(self, figure: plt.Figure, step: int, name: str) -> None:
"""
self.run.log({name: figure}, step=step)

def log_histogram(self, histogram: list[Any], step: int, name: str) -> None:
    """Log histogram metrics to swanlab.

    Histogram logging is not supported by this backend, so this does nothing.

    Args:
        histogram: List of histogram values (ignored)
        step: Global step value (ignored)
        name: Name of the metric (ignored)
    """
    return None


class GpuMetricSnapshot(TypedDict):
step: int
Expand Down Expand Up @@ -793,6 +816,10 @@ def log_plot(self, figure: plt.Figure, step: int, name: str) -> None:
figure.savefig(tmp_file.name, format="png", bbox_inches="tight")
mlflow.log_artifact(tmp_file.name, f"plots/{name}")

def log_histogram(self, histogram: list[Any], step: int, name: str) -> None:
    """Log histogram metrics to MLflow.

    Intentionally a no-op: the MLflow backend does not record histograms.

    Args:
        histogram: List of histogram values (ignored)
        step: Global step value (ignored)
        name: Name of the metric (ignored)
    """

def __del__(self) -> None:
"""Clean up resources when the logger is destroyed."""
try:
Expand Down Expand Up @@ -1017,6 +1044,17 @@ def log_plot_per_worker_timeline_metrics(
logger.log_plot(fig, step, f"{prefix}/average_{name}")
plt.close(fig)

def log_histogram(self, histogram: list[Any], step: int, name: str) -> None:
"""Log histogram metrics to all backends if available.

Args:
histogram: List of histogram values
step: Global step value
name: Name of the metric
"""
for logger in self.loggers:
logger.log_histogram(histogram, step, name)

def log_plot_token_mult_prob_error(
self, data: dict[str, Any], step: int, name: str
) -> None:
Expand Down
Loading